As AI models grow larger and datasets become more massive, training on a single machine becomes impractical or impossible. Distributed training enables training across multiple GPUs and machines, dramatically reducing training time and enabling models that wouldn’t fit on a single device. This comprehensive guide explores the principles, strategies, and practical implementation of distributed AI training.
Why Distributed Training?
The Scale Challenge
Modern AI training faces unprecedented scale:
- GPT-3: 175 billion parameters, trained on 300 billion tokens
- DALL-E 2: Trained on 650 million image-text pairs
- ImageNet training: Days on single GPU → hours with distributed
Single-machine limitations:
- GPU memory limits model size
- Training time becomes prohibitively long
- Some models simply cannot fit on one device
Distributed Training Benefits
- Speed: Linear or near-linear speedup with more GPUs
- Model size: Partition models across devices
- Batch size: Larger effective batches improve convergence
- Cost efficiency: Finish training faster, pay less for cloud resources
Parallelism Strategies
Data Parallelism
The most common approach: replicate the model on each device, split data between them.
`python
import torch
import torch.distributed as dist
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.parallel import DistributedDataParallel as DDP
def setup_distributed(rank, world_size):
    """Join the NCCL process group and pin this process to its GPU.

    Args:
        rank: Global rank of this process (also used as the CUDA device index).
        world_size: Total number of participating processes.
    """
    dist.init_process_group(
        init_method="env://",  # rendezvous via MASTER_ADDR / MASTER_PORT env vars
        backend="nccl",        # NCCL is the fastest backend for CUDA tensors
        rank=rank,
        world_size=world_size,
    )
    torch.cuda.set_device(rank)
def train_ddp(rank, world_size, model, dataset, epochs):
    """Run data-parallel training on one process/GPU.

    Args:
        rank: Global rank of this process (one process per GPU).
        world_size: Total number of processes in the job.
        model: Model to replicate on every GPU.
        dataset: Map-style dataset shared by all ranks.
        epochs: Number of passes over the dataset.
    """
    setup_distributed(rank, world_size)

    # Replicate the model on this GPU; DDP all-reduces grads during backward.
    model = model.to(rank)
    model = DDP(model, device_ids=[rank])

    # DistributedSampler partitions the dataset so each rank sees a disjoint shard.
    sampler = torch.utils.data.distributed.DistributedSampler(
        dataset,
        num_replicas=world_size,
        rank=rank,
    )
    loader = torch.utils.data.DataLoader(
        dataset,
        batch_size=32,  # per-GPU batch; effective batch = 32 * world_size
        sampler=sampler,
    )

    optimizer = torch.optim.Adam(model.parameters())

    for epoch in range(epochs):
        # Re-seed the sampler so the shuffle order differs every epoch.
        sampler.set_epoch(epoch)
        for inputs, targets in loader:
            inputs = inputs.to(rank)
            targets = targets.to(rank)

            optimizer.zero_grad()
            outputs = model(inputs)
            # F (torch.nn.functional) was used here without being imported;
            # the snippet's import block now provides it.
            loss = F.cross_entropy(outputs, targets)
            loss.backward()  # DDP overlaps the gradient all-reduce with backward
            optimizer.step()

    dist.destroy_process_group()
# Launch training: spawn one worker process per visible GPU.
# NOTE(review): `model` and `dataset` are not defined in this snippet — they
# must exist at module scope before this point; confirm in the real script.
if __name__ == "__main__":
    world_size = torch.cuda.device_count()
    torch.multiprocessing.spawn(
        train_ddp,  # spawn prepends the process rank as the first argument
        args=(world_size, model, dataset, 10),
        nprocs=world_size
    )
`
Gradient Synchronization
How DDP synchronizes gradients across devices:
`python
class ManualGradientSync:
    """Reference implementation of what DDP does after backward: average
    gradients across every process so all replicas apply the same update."""

    def __init__(self, model, world_size):
        self.model = model
        self.world_size = world_size

    def sync_gradients(self):
        """All-reduce every gradient, then divide by the process count."""
        for p in self.model.parameters():
            if p.grad is None:
                continue
            # Sum this gradient across all processes, then average.
            dist.all_reduce(p.grad.data, op=dist.ReduceOp.SUM)
            p.grad.data /= self.world_size

    def training_step(self, batch, optimizer):
        """One local forward/backward followed by an explicit gradient sync."""
        optimizer.zero_grad()
        loss = self.compute_loss(batch)  # NOTE: supplied by subclass/caller
        loss.backward()
        self.sync_gradients()
        # Every rank now holds identical averaged gradients.
        optimizer.step()
Model Parallelism
Split the model across devices when it's too large to fit on one:
`python
class ModelParallel(nn.Module):
    """Naive model parallelism: the first two layers live on cuda:0, the last
    two on cuda:1, with a single device-to-device activation copy between."""

    def __init__(self):
        super().__init__()
        # Stage 1 on GPU 0.
        self.layer1 = nn.Linear(1024, 2048).to('cuda:0')
        self.layer2 = nn.Linear(2048, 2048).to('cuda:0')
        # Stage 2 on GPU 1.
        self.layer3 = nn.Linear(2048, 2048).to('cuda:1')
        self.layer4 = nn.Linear(2048, 1000).to('cuda:1')

    def forward(self, x):
        # Run the first stage where its weights live.
        h = F.relu(self.layer1(x.to('cuda:0')))
        h = F.relu(self.layer2(h))
        # Ship activations to GPU 1 and finish there.
        h = h.to('cuda:1')
        h = F.relu(self.layer3(h))
        return self.layer4(h)
Pipeline Parallelism
Overlap computation and communication with micro-batches:
`python
class PipelineParallel:
    """Simplified GPipe-style pipeline: each stage lives on its own GPU and
    the batch is split into micro-batches that flow through the stages."""

    def __init__(self, stages, num_microbatches=4):
        self.stages = stages                    # one model segment per GPU
        self.num_microbatches = num_microbatches

    def forward(self, batch):
        """Push every micro-batch through every stage, then re-assemble."""
        current = list(batch.chunk(self.num_microbatches))
        for idx, stage in enumerate(self.stages):
            device = f'cuda:{idx}'
            # Outputs of this stage become the inputs of the next.
            current = [stage(mb.to(device)) for mb in current]
        return torch.cat(current, dim=0)

    def train_step(self, batch, targets, optimizer):
        """One optimization step over a full (re-assembled) batch."""
        predictions = self.forward(batch)
        loss = F.cross_entropy(predictions, targets)
        loss.backward()  # gradients flow back through every stage
        optimizer.step()
        optimizer.zero_grad()
class GPipeScheduler:
    """All-forward-then-all-backward schedule over micro-batches (GPipe)."""

    def __init__(self, stages, chunks=4):
        self.stages = stages
        self.chunks = chunks
        self.num_stages = len(stages)

    def forward_backward(self, batch, targets):
        """Forward every micro-batch, then backward in reverse order.

        Returns:
            The mean loss across micro-batches.
        """
        micro_batches = batch.chunk(self.chunks)
        micro_targets = targets.chunk(self.chunks)

        # activations[s][m] holds stage s's output for micro-batch m.
        activations = [[None] * self.chunks for _ in range(self.num_stages)]

        # Phase 1: forward all micro-batches through the whole pipeline.
        for m in range(self.chunks):
            h = micro_batches[m]
            for s, stage in enumerate(self.stages):
                h = stage(h.to(f'cuda:{s}'))
                activations[s][m] = h

        # Phase 2: backward, last micro-batch first (accumulates gradients).
        total_loss = 0
        for m in reversed(range(self.chunks)):
            loss = F.cross_entropy(activations[-1][m], micro_targets[m])
            loss.backward()
            total_loss += loss.item()

        return total_loss / self.chunks
`
Tensor Parallelism
Split individual layers across devices:
`python
class TensorParallelLinear(nn.Module):
    """Column-parallel linear layer: each rank owns a slice of the output
    features; the full output is assembled with an all-gather."""

    def __init__(self, in_features, out_features, world_size):
        super().__init__()
        self.world_size = world_size
        # This rank's slice of the output dimension.
        self.out_features_per_gpu = out_features // world_size
        self.weight = nn.Parameter(
            torch.randn(in_features, self.out_features_per_gpu)
        )
        self.bias = nn.Parameter(torch.zeros(self.out_features_per_gpu))

    def forward(self, x):
        # Weight is stored (in, out_slice); transpose into the (out, in)
        # layout that F.linear expects, producing this rank's partial output.
        partial = F.linear(x, self.weight.t(), self.bias)
        # Collect every rank's slice and stitch them along the feature dim.
        gathered = [torch.zeros_like(partial) for _ in range(self.world_size)]
        dist.all_gather(gathered, partial)
        return torch.cat(gathered, dim=-1)
class TensorParallelAttention(nn.Module):
    """Tensor-parallel multi-head attention (Megatron-style).

    Each rank computes attention over its own slice of the heads; the output
    projection is row-parallel, so a final all-reduce sums the partial results.

    NOTE(review): forward() calls `self.attention(q, k, v)` but no such method
    is defined in this snippet — presumably scaled dot-product attention over
    this rank's heads; confirm before use.
    """
    def __init__(self, hidden_size, num_heads, world_size):
        super().__init__()
        # Heads are partitioned across ranks.
        self.heads_per_gpu = num_heads // world_size
        self.hidden_per_head = hidden_size // num_heads
        # Fused Q/K/V projection (column parallel). world_size=1 is passed
        # because the output dimension is already divided by world_size here.
        self.qkv = TensorParallelLinear(
            hidden_size,
            3 * hidden_size // world_size,
            1  # Already split
        )
        # Output projection (row parallel): maps this rank's slice back to the
        # full hidden size; partial outputs are summed across ranks below.
        self.output = nn.Linear(
            hidden_size // world_size,
            hidden_size
        )
    def forward(self, x):
        # Project into this rank's Q, K, V slices.
        qkv = self.qkv(x)
        q, k, v = qkv.chunk(3, dim=-1)
        # Attention over the locally-owned heads only.
        attn_output = self.attention(q, k, v)
        # Row-parallel output projection, then sum partials across ranks.
        output = self.output(attn_output)
        dist.all_reduce(output)
        return output
`
Fully Sharded Data Parallel (FSDP)
Combine data parallelism with parameter sharding:
`python
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp.wrap import size_based_auto_wrap_policy
def train_with_fsdp(rank, world_size):
    """Train a large model with FSDP: parameters, gradients and optimizer
    state are sharded across ranks instead of fully replicated.

    NOTE(review): `LargeTransformer`, `epochs`, `dataloader` and
    `compute_loss` are assumed to be defined elsewhere in the real script.
    """
    # Fix: ShardingStrategy was referenced below without ever being imported,
    # which raised NameError at the FSDP() call.
    from torch.distributed.fsdp import ShardingStrategy

    setup_distributed(rank, world_size)

    model = LargeTransformer()

    # Each rank keeps only a shard of every wrapped module's parameters;
    # full parameters are materialized just-in-time around forward/backward.
    fsdp_model = FSDP(
        model,
        auto_wrap_policy=size_based_auto_wrap_policy,
        cpu_offload=None,  # pass CPUOffload(offload_params=True) to trade speed for memory
        sharding_strategy=ShardingStrategy.FULL_SHARD,
    )

    optimizer = torch.optim.Adam(fsdp_model.parameters())

    for epoch in range(epochs):
        for batch in dataloader:
            optimizer.zero_grad()
            outputs = fsdp_model(batch)  # FSDP all-gathers shards before forward
            loss = compute_loss(outputs)
            loss.backward()              # grads are reduce-scattered back to shards
            optimizer.step()
Optimization for Large Batches
Learning Rate Scaling
`python
class LearningRateScaler:
    """Rules of thumb for choosing a learning rate when the effective batch
    size changes in distributed training."""

    @staticmethod
    def linear_scaling(base_lr, base_batch_size, actual_batch_size):
        """Linear scaling rule: lr grows proportionally with batch size."""
        ratio = actual_batch_size / base_batch_size
        return ratio * base_lr

    @staticmethod
    def sqrt_scaling(base_lr, base_batch_size, actual_batch_size):
        """Square-root scaling: a gentler increase for very large batches."""
        ratio = actual_batch_size / base_batch_size
        return math.sqrt(ratio) * base_lr
class GradualWarmup:
    """Linear learning-rate warmup over the first `warmup_steps` updates."""

    def __init__(self, optimizer, warmup_steps, target_lr):
        self.optimizer = optimizer
        self.warmup_steps = warmup_steps
        self.target_lr = target_lr
        self.step_count = 0

    def step(self):
        """Advance one step; ramp lr from 0 to target_lr, then leave it alone."""
        self.step_count += 1
        if self.step_count > self.warmup_steps:
            return  # warmup finished
        warm_lr = self.target_lr * (self.step_count / self.warmup_steps)
        for group in self.optimizer.param_groups:
            group['lr'] = warm_lr
LARS and LAMB Optimizers
`python
class LARS(torch.optim.Optimizer):
    """Layer-wise Adaptive Rate Scaling (You et al., 2017) for large-batch SGD.

    Each parameter receives a local learning rate proportional to
    ||w|| / ||grad||, keeping the update magnitude in proportion to the
    weight magnitude layer by layer.
    """

    def __init__(self, params, lr, weight_decay=0, momentum=0.9,
                 trust_coefficient=0.001):
        defaults = dict(
            lr=lr, weight_decay=weight_decay,
            momentum=momentum, trust_coefficient=trust_coefficient,
        )
        super().__init__(params, defaults)

    @torch.no_grad()
    def step(self, closure=None):
        """Perform a single optimization step.

        Args:
            closure: Optional callable that re-evaluates the model and returns
                the loss. Fix: the torch.optim convention is step(closure=None)
                returning the loss; the previous signature rejected closures.
        """
        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            for p in group['params']:
                if p.grad is None:
                    continue
                grad = p.grad
                # Fold L2 weight decay into the gradient.
                if group['weight_decay'] != 0:
                    grad = grad.add(p, alpha=group['weight_decay'])
                # Trust ratio: scale the step to the parameter's own magnitude.
                param_norm = p.norm()
                grad_norm = grad.norm()
                if param_norm > 0 and grad_norm > 0:
                    # .item() gives a plain float for the `alpha=` argument.
                    local_lr = (group['trust_coefficient']
                                * param_norm / grad_norm).item()
                else:
                    local_lr = 1.0
                # Heavy-ball momentum on the locally-scaled gradient.
                state = self.state[p]
                if 'momentum_buffer' not in state:
                    state['momentum_buffer'] = torch.zeros_like(p)
                buf = state['momentum_buffer']
                buf.mul_(group['momentum']).add_(grad, alpha=local_lr)
                p.add_(buf, alpha=-group['lr'])

        return loss
class LAMB(torch.optim.Optimizer):
    """LAMB (You et al., 2019): an Adam update rescaled per layer by the trust
    ratio ||w|| / ||adam_step||, enabling very large batch sizes."""

    def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-6,
                 weight_decay=0):
        defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay)
        super().__init__(params, defaults)

    @torch.no_grad()
    def step(self, closure=None):
        """Perform a single optimization step.

        Args:
            closure: Optional loss-re-evaluating callable. Fix: follows the
                torch.optim step(closure=None) convention and returns the loss;
                the previous signature rejected closures.
        """
        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            for p in group['params']:
                if p.grad is None:
                    continue
                grad = p.grad
                state = self.state[p]
                # Lazy state initialization on the first update.
                if len(state) == 0:
                    state['step'] = 0
                    state['exp_avg'] = torch.zeros_like(p)
                    state['exp_avg_sq'] = torch.zeros_like(p)
                exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']
                beta1, beta2 = group['betas']
                state['step'] += 1
                # Adam first/second moment updates.
                exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)
                exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)
                # Bias correction for the zero-initialized moments.
                bias_correction1 = 1 - beta1 ** state['step']
                bias_correction2 = 1 - beta2 ** state['step']
                adam_step = (exp_avg / bias_correction1) / (
                    (exp_avg_sq / bias_correction2).sqrt() + group['eps']
                )
                # Decoupled (AdamW-style) weight decay folded into the direction.
                if group['weight_decay'] != 0:
                    adam_step.add_(p, alpha=group['weight_decay'])
                # Trust ratio keeps the update proportional to the weight norm.
                weight_norm = p.norm()
                adam_norm = adam_step.norm()
                if weight_norm > 0 and adam_norm > 0:
                    trust_ratio = (weight_norm / adam_norm).item()
                else:
                    trust_ratio = 1.0
                p.add_(adam_step, alpha=-group['lr'] * trust_ratio)

        return loss
`
Gradient Accumulation
`python
class GradientAccumulator:
    """Simulate a large batch by summing gradients over several mini-batches
    and stepping the optimizer only every `accumulation_steps` calls."""

    def __init__(self, model, optimizer, accumulation_steps):
        self.model = model
        self.optimizer = optimizer
        self.accumulation_steps = accumulation_steps
        self.step_count = 0

    def training_step(self, batch):
        """Accumulate one mini-batch; return True iff weights were updated."""
        # Divide the loss so the summed gradients average over the window.
        scaled_loss = self.compute_loss(batch) / self.accumulation_steps
        scaled_loss.backward()

        self.step_count += 1
        if self.step_count % self.accumulation_steps != 0:
            return False  # still accumulating

        # Window complete: clip, step, and reset gradients.
        torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=1.0)
        self.optimizer.step()
        self.optimizer.zero_grad()
        return True
Communication Optimization
Gradient Compression
`python
class GradientCompressor:
    """Lossy gradient compression schemes that cut communication volume."""

    @staticmethod
    def top_k(tensor, ratio=0.01):
        """Zero out all but the largest-magnitude `ratio` fraction of entries."""
        k = max(1, int(tensor.numel() * ratio))
        _, indices = tensor.abs().flatten().topk(k)
        compressed = torch.zeros_like(tensor).flatten()
        compressed[indices] = tensor.flatten()[indices]
        return compressed.view_as(tensor)

    @staticmethod
    def random_k(tensor, ratio=0.01):
        """Keep a random `ratio` fraction, rescaled to remain unbiased."""
        mask = torch.rand_like(tensor) < ratio
        return tensor * mask / ratio  # divide by keep-probability

    @staticmethod
    def quantize(tensor, bits=8):
        """Uniformly quantize to `bits` levels (bits <= 8: stored as uint8).

        Returns:
            (quantized, scale, min_val) so the caller can dequantize with
            quantized * scale + min_val.
        """
        min_val = tensor.min()
        max_val = tensor.max()
        scale = (max_val - min_val) / (2 ** bits - 1)
        # Fix: a constant tensor has max == min, making scale 0 and the
        # division below produce inf/NaN; map everything to level 0 instead.
        if scale == 0:
            return torch.zeros_like(tensor, dtype=torch.uint8), scale, min_val
        quantized = ((tensor - min_val) / scale).round().to(torch.uint8)
        return quantized, scale, min_val
class PowerSGD:
    """Rank-`rank` gradient compression via one power-iteration step
    (Vogels et al., 2019). P/Q factors are kept per parameter as warm starts."""

    def __init__(self, rank=4):
        self.rank = rank
        self.memory = {}  # name -> {'P': (m, r), 'Q': (n, r)} warm-start factors

    def compress(self, gradient, name):
        """Return a low-rank approximation of `gradient` (keyed by `name`).

        1-D tensors are returned unchanged — there is nothing to factor.
        """
        if gradient.dim() < 2:
            return gradient

        # Flatten trailing dimensions so we factor a 2-D matrix.
        original_shape = gradient.shape
        gradient = gradient.view(gradient.shape[0], -1)
        m, n = gradient.shape

        if name not in self.memory:
            self.memory[name] = {
                'P': torch.randn(m, self.rank, device=gradient.device),
                'Q': torch.randn(n, self.rank, device=gradient.device),
            }
        P = self.memory[name]['P']
        Q = self.memory[name]['Q']

        # One power-iteration step with re-orthogonalization.
        # Fix: torch.qr is deprecated (removed in recent PyTorch); use
        # torch.linalg.qr, which returns the same reduced factorization.
        Q, _ = torch.linalg.qr(gradient.t() @ P)
        P, _ = torch.linalg.qr(gradient @ Q)
        self.memory[name]['P'] = P
        self.memory[name]['Q'] = Q

        # Project the gradient onto the learned rank-r subspaces.
        approx = P @ (P.t() @ gradient @ Q) @ Q.t()
        return approx.view(original_shape)
`
Overlapping Communication and Computation
`python
class OverlappedDDP:
    """Hand-rolled DDP-style gradient averaging that overlaps communication
    with the backward pass via per-parameter gradient hooks."""

    def __init__(self, model, world_size):
        self.model = model
        self.world_size = world_size
        # Fix: this list was only ever created inside synchronize(), so the
        # very first backward pass crashed with AttributeError in the hooks.
        self.async_handles = []
        self._register_hooks()

    def _register_hooks(self):
        """Attach a gradient hook to every trainable parameter."""
        self.handles = []
        for param in self.model.parameters():
            if param.requires_grad:
                handle = param.register_hook(self._make_hook(param))
                self.handles.append(handle)

    def _make_hook(self, param):
        """Build a hook that launches an async all-reduce the moment this
        parameter's gradient becomes available."""
        def hook(grad):
            work = dist.all_reduce(grad, op=dist.ReduceOp.SUM, async_op=True)
            # Remember the in-flight work object for synchronize().
            self.async_handles.append((param, work))
            return grad
        return hook

    def synchronize(self):
        """Block until every in-flight all-reduce finishes, then average."""
        for param, work in self.async_handles:
            work.wait()
            param.grad /= self.world_size
        self.async_handles = []
Multi-Node Training
Launch Configuration
`python
# torchrun launch script
"""
# On node 0 (master):
torchrun --nproc_per_node=8 \
--nnodes=4 \
--node_rank=0 \
--master_addr="10.0.0.1" \
--master_port=29500 \
train.py
# On node 1:
torchrun --nproc_per_node=8 \
--nnodes=4 \
--node_rank=1 \
--master_addr="10.0.0.1" \
--master_port=29500 \
train.py
"""
def main():
    """Entry point for a torchrun-launched worker process.

    NOTE(review): relies on `os`, `dist`, `torch` and a `train` function
    being imported/defined elsewhere in the real script — confirm.
    """
    # torchrun sets these environment variables for every worker it spawns.
    local_rank = int(os.environ['LOCAL_RANK'])   # GPU index on this node
    global_rank = int(os.environ['RANK'])        # unique rank across all nodes
    world_size = int(os.environ['WORLD_SIZE'])   # total number of workers
    # Rendezvous via env:// defaults (MASTER_ADDR/MASTER_PORT from torchrun).
    dist.init_process_group(backend='nccl')
    # Bind this process to its local GPU before creating any CUDA tensors.
    torch.cuda.set_device(local_rank)
    # Training code...
    train(global_rank, world_size)

if __name__ == '__main__':
    main()
`
Checkpointing
`python
class DistributedCheckpointer:
    """Checkpoint helper for DDP runs: rank 0 writes, every rank loads."""

    def __init__(self, model, optimizer, rank):
        self.model = model
        self.optimizer = optimizer
        self.rank = rank

    def _unwrapped(self):
        """Return the bare model, stripping a DDP wrapper when present."""
        return self.model.module if hasattr(self.model, 'module') else self.model

    def save(self, path, epoch):
        """Write a checkpoint from rank 0, then barrier so no rank races ahead."""
        if self.rank == 0:
            torch.save({
                'epoch': epoch,
                'model_state_dict': self._unwrapped().state_dict(),
                'optimizer_state_dict': self.optimizer.state_dict(),
            }, path)
        # Everyone waits until the file is fully written.
        dist.barrier()

    def load(self, path):
        """Restore model and optimizer state on every rank; return the epoch."""
        # Tensors saved from cuda:0 are remapped onto this rank's GPU.
        map_location = {'cuda:0': f'cuda:{self.rank}'}
        checkpoint = torch.load(path, map_location=map_location)
        self._unwrapped().load_state_dict(checkpoint['model_state_dict'])
        self.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        return checkpoint['epoch']
`
Handling Failures
`python
class FaultTolerantTrainer:
    """Checkpoint periodically and dump a final checkpoint if training crashes.

    NOTE(review): `training_step`, `save_checkpoint` and `load_checkpoint` are
    assumed to be implemented elsewhere (e.g. on a subclass); `glob` and `os`
    must be imported by the enclosing script.
    """

    def __init__(self, model, checkpoint_dir):
        self.model = model
        self.checkpoint_dir = checkpoint_dir
        self.checkpoint_interval = 100  # steps between periodic checkpoints

    def train(self, dataloader, epochs):
        """Run training, resuming from the newest checkpoint when one exists."""
        step = 0
        start_epoch = 0
        checkpoint_path = self.find_latest_checkpoint()
        if checkpoint_path:
            start_epoch, step = self.load_checkpoint(checkpoint_path)
            print(f"Resumed from epoch {start_epoch}, step {step}")
        # Fix: `epoch` was unbound if a failure happened before the first loop
        # iteration, turning the crash handler itself into a NameError.
        epoch = start_epoch
        try:
            for epoch in range(start_epoch, epochs):
                for batch in dataloader:
                    self.training_step(batch)
                    step += 1
                    if step % self.checkpoint_interval == 0:
                        self.save_checkpoint(epoch, step)
        except Exception as e:
            # Best effort: persist progress before re-raising the failure.
            print(f"Training failed: {e}")
            self.save_checkpoint(epoch, step)
            raise

    def find_latest_checkpoint(self):
        """Return the most recently created checkpoint file, or None."""
        checkpoints = glob.glob(f"{self.checkpoint_dir}/checkpoint_*.pt")
        if checkpoints:
            return max(checkpoints, key=os.path.getctime)
        return None
class ElasticTrainer:
    """Support elastic training (dynamic scaling of the worker pool).

    NOTE(review): this sketch references `local_rank` and an
    `adjust_learning_rate` method that are not defined in the snippet;
    both must come from the surrounding training script — confirm.
    """
    def __init__(self, model):
        self.model = model
    def handle_membership_change(self):
        """Called when workers join or leave the job."""
        # World size reflects the new membership after the rendezvous.
        new_world_size = dist.get_world_size()
        # Re-wrap the underlying module so DDP rebuilds against the new group.
        self.model = DDP(self.model.module, device_ids=[local_rank])
        # Keep the effective learning rate consistent with the replica count.
        self.adjust_learning_rate(new_world_size)
`
Profiling and Debugging
`python
class DistributedProfiler:
    """Lightweight wall-clock profiler for distributed training steps."""

    def __init__(self, rank):
        self.rank = rank
        self.timings = defaultdict(list)  # section name -> list of seconds

    @contextmanager
    def time_section(self, name):
        """Record the wall-clock time of the enclosed code under `name`."""
        started = time.time()
        yield
        self.timings[name].append(time.time() - started)

    def profile_training_step(self, model, batch, optimizer):
        """Run one training step, timing each phase separately."""
        with self.time_section('data_load'):
            inputs, targets = batch
            inputs = inputs.cuda()
            targets = targets.cuda()
        with self.time_section('forward'):
            outputs = model(inputs)
        with self.time_section('loss'):
            loss = F.cross_entropy(outputs, targets)
        with self.time_section('backward'):
            optimizer.zero_grad()
            loss.backward()
        with self.time_section('sync'):
            pass  # with DDP, the gradient all-reduce overlaps backward
        with self.time_section('optimizer'):
            optimizer.step()

    def report(self):
        """Print mean per-section times (rank 0 only, to avoid log spam)."""
        if self.rank != 0:
            return
        print("\n=== Training Profile ===")
        for name, samples in self.timings.items():
            print(f"{name}: {np.mean(samples) * 1000:.2f} ms")
`
Conclusion
Distributed training is essential for modern AI at scale. From simple data parallelism to sophisticated hybrid strategies combining data, model, pipeline, and tensor parallelism, these techniques enable training models of unprecedented size and capability.
Key takeaways:
- Data parallelism: Simplest approach, scales well for most models
- Model parallelism: Split large models across devices
- Pipeline parallelism: Overlap computation with micro-batches
- FSDP: Shard parameters for memory efficiency
- Optimization matters: LARS/LAMB, warmup, gradient accumulation
- Communication optimization: Compression, overlapping, efficient collectives
Whether training on a few GPUs or thousands, understanding distributed training principles is crucial for anyone working with modern AI systems.