Mixed precision training has become a standard technique for accelerating deep learning. By using lower-precision numerical formats like FP16 or BF16 alongside FP32, we can dramatically speed up training while reducing memory usage. This comprehensive guide explores the principles, implementation, and best practices of mixed precision training.
Understanding Numerical Precision
Floating-Point Formats
Different floating-point formats trade off precision for efficiency:
FP32 (Single Precision):
- 1 sign bit, 8 exponent bits, 23 mantissa bits
- Range: ±3.4 × 10^38
- Standard precision for neural network training
FP16 (Half Precision):
- 1 sign bit, 5 exponent bits, 10 mantissa bits
- Range: ±65,504
- Half the memory of FP32; typically 2-8x faster on GPUs with FP16 tensor cores
BF16 (Brain Floating Point):
- 1 sign bit, 8 exponent bits, 7 mantissa bits
- Same range as FP32, reduced precision
- Developed by Google for ML workloads
TF32 (TensorFloat-32):
- 1 sign bit, 8 exponent bits, 10 mantissa bits
- NVIDIA format for Ampere GPUs
- Internal format for tensor cores
```python
import torch
# Check available formats
print(f"FP32: {torch.float32}")
print(f"FP16: {torch.float16}")
print(f"BF16: {torch.bfloat16}")
# Memory comparison
tensor_fp32 = torch.randn(1000, 1000, dtype=torch.float32)
tensor_fp16 = torch.randn(1000, 1000, dtype=torch.float16)
tensor_bf16 = torch.randn(1000, 1000, dtype=torch.bfloat16)
print(f"FP32 memory: {tensor_fp32.element_size() * tensor_fp32.nelement() / 1e6:.2f} MB")
print(f"FP16 memory: {tensor_fp16.element_size() * tensor_fp16.nelement() / 1e6:.2f} MB")
print(f"BF16 memory: {tensor_bf16.element_size() * tensor_bf16.nelement() / 1e6:.2f} MB")
```
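These trade-offs are easy to verify directly: BF16 keeps FP32's 8 exponent bits, so large values that overflow FP16 remain finite, while FP16 also flushes very small values to zero.

```python
import torch

# BF16 shares FP32's exponent range; FP16 overflows past ~65,504.
big = torch.tensor(1e30)
print(torch.isinf(big.half()).item())      # True: overflows FP16
print(torch.isinf(big.bfloat16()).item())  # False: fits in BF16

# At the small end, FP16 underflows below ~6e-8.
tiny = torch.tensor(1e-8)
print(tiny.half().item())  # 0.0: flushed to zero in FP16
```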
Why Mixed Precision?
Using FP16 everywhere would cause problems:
- Underflow: Small gradients become zero
- Overflow: Large values exceed FP16 range
- Precision loss: Weight updates may be lost
Mixed precision solves this by:
- Computing in FP16 for speed
- Storing master weights in FP32 for precision
- Using loss scaling to prevent underflow
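The effect of loss scaling can be seen on a single tensor: a gradient too small for FP16 survives when multiplied by the scale before the cast and divided back out in FP32 afterwards.

```python
import torch

grad = torch.tensor(1e-8)      # a gradient below FP16's subnormal range
print(grad.half().item())      # 0.0: lost without scaling

scale = 2.0 ** 16
scaled = (grad * scale).half() # ~6.55e-4, comfortably representable in FP16
recovered = scaled.float() / scale
print(abs(recovered.item() - 1e-8) < 1e-10)  # True: gradient recovered
```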
The Mixed Precision Recipe
Three Key Techniques
- FP16 Forward/Backward: Compute activations and gradients in FP16
- FP32 Master Weights: Maintain full-precision copy of weights
- Loss Scaling: Scale loss to prevent gradient underflow
```python
import torch
class ManualMixedPrecision:
"""Understanding the core mixed precision mechanism."""
def __init__(self, model):
# FP32 master weights
        self.master_weights = {
            name: param.detach().clone().float()
            for name, param in model.named_parameters()
        }
# FP16 model for computation
self.model = model.half()
# Loss scale (starts high, adjusted dynamically)
self.loss_scale = 2**16
self.scale_factor = 2
self.scale_window = 2000
self.step_count = 0
self.overflow_count = 0
def forward_backward(self, inputs, targets, criterion):
# Forward in FP16
inputs = inputs.half()
outputs = self.model(inputs)
# Compute loss and scale it
loss = criterion(outputs, targets)
scaled_loss = loss * self.loss_scale
# Backward (gradients are in FP16)
scaled_loss.backward()
return loss.item()
def step(self, optimizer):
# Check for overflow in gradients
overflow = False
for param in self.model.parameters():
if param.grad is not None:
if torch.isinf(param.grad).any() or torch.isnan(param.grad).any():
overflow = True
break
if overflow:
# Skip update, reduce scale
self.overflow_count += 1
self.loss_scale /= self.scale_factor
self.model.zero_grad()
return False
# Unscale gradients and update master weights
for name, param in self.model.named_parameters():
if param.grad is not None:
# Unscale gradient
grad_fp32 = param.grad.float() / self.loss_scale
# Update master weight
self.master_weights[name].grad = grad_fp32
        # Optimizer step on the FP32 master weights (the optimizer must
        # have been constructed over self.master_weights, not the model)
        optimizer.step()
# Copy master weights back to FP16 model
for name, param in self.model.named_parameters():
param.data.copy_(self.master_weights[name].half())
self.model.zero_grad()
# Maybe increase scale
self.step_count += 1
if self.step_count % self.scale_window == 0:
self.loss_scale *= self.scale_factor
return True
```
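The FP32 master copy matters because FP16 cannot represent small weight updates near existing weights; around 1.0 the FP16 spacing is about 9.8e-4, so a typical per-step update simply rounds away.

```python
import torch

w = torch.tensor(1.0, dtype=torch.float16)
lr_update = 1e-4  # smaller than FP16's spacing (~9.8e-4) around 1.0

# Updating directly in FP16: the update rounds away entirely
print((w + lr_update).item())  # 1.0

# Accumulating in an FP32 master weight preserves it
master = w.float()
for _ in range(10):
    master += lr_update
print(round(master.item(), 4))  # 1.001
```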
PyTorch Automatic Mixed Precision
Using torch.cuda.amp
```python
import torch
from torch.cuda.amp import autocast, GradScaler
def train_with_amp(model, dataloader, optimizer, criterion, epochs):
# Create gradient scaler
scaler = GradScaler()
for epoch in range(epochs):
for inputs, targets in dataloader:
inputs = inputs.cuda()
targets = targets.cuda()
optimizer.zero_grad()
# Forward pass with automatic casting
with autocast():
outputs = model(inputs)
loss = criterion(outputs, targets)
# Backward pass with gradient scaling
scaler.scale(loss).backward()
# Optimizer step with unscaling
scaler.step(optimizer)
# Update scaler for next iteration
scaler.update()
# More detailed control
class AMPTrainer:
def __init__(self, model, optimizer, criterion):
self.model = model.cuda()
self.optimizer = optimizer
self.criterion = criterion
self.scaler = GradScaler(
init_scale=2**16,
growth_factor=2.0,
backoff_factor=0.5,
growth_interval=2000,
enabled=True
)
def train_step(self, inputs, targets):
inputs = inputs.cuda()
targets = targets.cuda()
self.optimizer.zero_grad()
# Autocast context for forward pass
with autocast(dtype=torch.float16):
outputs = self.model(inputs)
loss = self.criterion(outputs, targets)
# Scale loss and backward
self.scaler.scale(loss).backward()
# Optional: Gradient clipping with unscaling
self.scaler.unscale_(self.optimizer)
torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=1.0)
# Check for inf/nan before stepping
self.scaler.step(self.optimizer)
self.scaler.update()
return loss.item()
@torch.no_grad()
def validate(self, dataloader):
self.model.eval()
total_loss = 0
for inputs, targets in dataloader:
inputs = inputs.cuda()
targets = targets.cuda()
with autocast(dtype=torch.float16):
outputs = self.model(inputs)
loss = self.criterion(outputs, targets)
total_loss += loss.item()
self.model.train()
return total_loss / len(dataloader)
```
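One practical detail when using GradScaler: its dynamically adjusted scale is training state, so checkpoint it alongside the model and optimizer. A minimal sketch ("checkpoint.pt" is a placeholder path):

```python
import torch
from torch.cuda.amp import GradScaler

scaler = GradScaler(enabled=torch.cuda.is_available())

# Save scaler state with the rest of the checkpoint
checkpoint = {"scaler": scaler.state_dict()}
torch.save(checkpoint, "checkpoint.pt")

# On restart, restore it so the current loss scale is not reset
restored = torch.load("checkpoint.pt")
scaler.load_state_dict(restored["scaler"])
```

Without this, a resumed run restarts at `init_scale` and may skip steps while the scale re-converges.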
BFloat16 Training
```python
import torch
import torch.nn.functional as F
from torch.cuda.amp import autocast
def train_with_bf16(model, dataloader, optimizer, criterion):
"""BFloat16 training - simpler than FP16, no loss scaling needed."""
# Check BF16 support
if not torch.cuda.is_bf16_supported():
raise RuntimeError("BF16 not supported on this GPU")
model = model.cuda()
for inputs, targets in dataloader:
inputs = inputs.cuda()
targets = targets.cuda()
optimizer.zero_grad()
# Autocast with bfloat16
with autocast(dtype=torch.bfloat16):
outputs = model(inputs)
loss = criterion(outputs, targets)
# No scaling needed for BF16
loss.backward()
optimizer.step()
class BF16Trainer:
"""BF16 training without GradScaler."""
def __init__(self, model, optimizer):
self.model = model.cuda().bfloat16() # Convert model to BF16
self.optimizer = optimizer
def train_step(self, inputs, targets):
inputs = inputs.cuda().bfloat16()
targets = targets.cuda()
self.optimizer.zero_grad()
outputs = self.model(inputs)
loss = F.cross_entropy(outputs.float(), targets) # Loss in FP32
loss.backward()
self.optimizer.step()
return loss.item()
```
Operations and Precision Selection
What Should Be FP16 vs FP32?
```python
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.cuda.amp import autocast
# Operations safe for FP16
safe_for_fp16 = [
'linear', # Matrix multiplications
'conv2d', # Convolutions
'batch_norm', # Batch normalization
'relu', # ReLU activation
'max_pool', # Pooling operations
'dropout', # Dropout
]
# Operations that should stay FP32
keep_fp32 = [
'softmax', # Numerical stability
'cross_entropy', # Loss computation
'layer_norm', # Normalization
'batch_norm_stats', # Running statistics
'optimizer_step', # Weight updates
]
# Custom autocast behavior
class CustomMixedPrecision(nn.Module):
def __init__(self):
super().__init__()
self.conv = nn.Conv2d(3, 64, 3)
self.norm = nn.LayerNorm(64)
def forward(self, x):
# Autocast handles most operations
with autocast():
x = self.conv(x) # FP16
x = F.relu(x) # FP16
# Force specific operations to FP32
with autocast(enabled=False):
x = self.norm(x.float()) # FP32
return x
```
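Autocast applies these choices per operator. The behavior can be observed directly; the sketch below uses CPU autocast with BF16 so it runs anywhere (on CUDA the low-precision dtype would typically be FP16):

```python
import torch

a = torch.randn(8, 8)  # FP32 inputs
with torch.autocast("cpu", dtype=torch.bfloat16):
    mm = a @ a       # matmuls run in the low-precision dtype
    total = a.sum()  # ops not on the low-precision list stay FP32
print(mm.dtype)     # torch.bfloat16
print(total.dtype)  # torch.float32
```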
Custom Autocast Ops
```python
import torch
import torch.nn as nn
import torch.nn.functional as F
class FP32LayerNorm(nn.LayerNorm):
"""Layer norm that always runs in FP32."""
def forward(self, x):
input_dtype = x.dtype
x = x.float()
x = super().forward(x)
return x.to(input_dtype)
class FP32Softmax(nn.Module):
"""Softmax with FP32 computation."""
def forward(self, x, dim=-1):
return F.softmax(x.float(), dim=dim).to(x.dtype)
# custom_fwd/custom_bwd are meant to decorate the forward/backward
# static methods of a torch.autograd.Function, pinning their precision
# under autocast
class FP32Exp(torch.autograd.Function):
    """Elementwise exp that always computes in FP32 under autocast."""

    @staticmethod
    @torch.cuda.amp.custom_fwd(cast_inputs=torch.float32)
    def forward(ctx, x):
        y = torch.exp(x)
        ctx.save_for_backward(y)
        return y

    @staticmethod
    @torch.cuda.amp.custom_bwd
    def backward(ctx, grad_output):
        (y,) = ctx.saved_tensors
        return grad_output * y
```
Loss Scaling Strategies
Static Loss Scaling
```python
import torch
import torch.nn.functional as F
class StaticLossScaling:
"""Fixed loss scale (simpler but less flexible)."""
def __init__(self, model, optimizer, scale=1024.0):
self.model = model.half().cuda()
self.optimizer = optimizer
self.scale = scale
def train_step(self, inputs, targets):
inputs = inputs.half().cuda()
targets = targets.cuda()
self.optimizer.zero_grad()
outputs = self.model(inputs)
loss = F.cross_entropy(outputs, targets)
# Scale and backward
(loss * self.scale).backward()
# Unscale gradients
for param in self.model.parameters():
if param.grad is not None:
param.grad.data /= self.scale
self.optimizer.step()
return loss.item()
```
Dynamic Loss Scaling
```python
import torch
class DynamicLossScaler:
"""Automatically adjust loss scale based on overflow detection."""
def __init__(self, init_scale=2**16, scale_factor=2.0,
scale_window=2000, min_scale=1.0, max_scale=2**24):
self.scale = init_scale
self.scale_factor = scale_factor
self.scale_window = scale_window
self.min_scale = min_scale
self.max_scale = max_scale
self.steps_since_overflow = 0
self.overflow_history = []
def has_overflow(self, params):
"""Check if any gradients contain inf or nan."""
for param in params:
if param.grad is not None:
grad_data = param.grad.data
if torch.isinf(grad_data).any() or torch.isnan(grad_data).any():
return True
return False
def update_scale(self, overflow):
"""Update scale based on overflow status."""
if overflow:
# Reduce scale on overflow
self.scale = max(self.min_scale, self.scale / self.scale_factor)
self.steps_since_overflow = 0
self.overflow_history.append(True)
else:
self.steps_since_overflow += 1
self.overflow_history.append(False)
# Increase scale after enough successful steps
if self.steps_since_overflow >= self.scale_window:
self.scale = min(self.max_scale, self.scale * self.scale_factor)
self.steps_since_overflow = 0
def state_dict(self):
return {
'scale': self.scale,
'steps_since_overflow': self.steps_since_overflow
}
def load_state_dict(self, state_dict):
self.scale = state_dict['scale']
self.steps_since_overflow = state_dict['steps_since_overflow']
```
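The update policy above is easy to sanity-check in isolation. Below is a standalone re-implementation of just the scale schedule (the `run_schedule` helper is illustrative, with a small window so the doubling is visible):

```python
def run_schedule(scale, overflows, factor=2.0, window=3,
                 min_scale=1.0, max_scale=2**24):
    """Halve the scale on overflow; double it after `window` clean steps."""
    clean_steps = 0
    for overflow in overflows:
        if overflow:
            scale = max(min_scale, scale / factor)
            clean_steps = 0
        else:
            clean_steps += 1
            if clean_steps >= window:
                scale = min(max_scale, scale * factor)
                clean_steps = 0
    return scale

print(run_schedule(65536.0, [True]))                       # 32768.0
print(run_schedule(65536.0, [True, False, False, False]))  # 65536.0
```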
Mixed Precision with Distributed Training
```python
import torch
import torch.distributed as dist
import torch.nn.functional as F
from torch.cuda.amp import autocast, GradScaler
from torch.nn.parallel import DistributedDataParallel as DDP
def distributed_mixed_precision_training(rank, world_size):
# Setup distributed
dist.init_process_group('nccl', rank=rank, world_size=world_size)
torch.cuda.set_device(rank)
# Create model and wrap with DDP
model = MyModel().cuda()
model = DDP(model, device_ids=[rank])
optimizer = torch.optim.Adam(model.parameters())
scaler = GradScaler()
for epoch in range(epochs):
for inputs, targets in dataloader:
inputs = inputs.cuda()
targets = targets.cuda()
optimizer.zero_grad()
with autocast():
outputs = model(inputs)
loss = F.cross_entropy(outputs, targets)
scaler.scale(loss).backward()
# Gradient clipping with unscaling
scaler.unscale_(optimizer)
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
scaler.step(optimizer)
scaler.update()
# FSDP with mixed precision
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp import MixedPrecision, ShardingStrategy
def fsdp_mixed_precision(model, optimizer, dataloader):
mp_policy = MixedPrecision(
        param_dtype=torch.bfloat16,   # Parameters in BF16
        reduce_dtype=torch.bfloat16,  # Gradient reduction in BF16
        buffer_dtype=torch.bfloat16,  # Buffers in BF16
cast_forward_inputs=True
)
model = FSDP(
model,
mixed_precision=mp_policy,
sharding_strategy=ShardingStrategy.FULL_SHARD
)
    # No GradScaler needed with BF16; for FP16, use FSDP's ShardedGradScaler
for inputs, targets in dataloader:
with autocast():
outputs = model(inputs)
loss = F.cross_entropy(outputs, targets)
loss.backward()
optimizer.step()
optimizer.zero_grad()
```
Handling Edge Cases
Dealing with Overflow
```python
import math

import torch
import torch.nn.functional as F
from torch.cuda.amp import autocast, GradScaler
class RobustMixedPrecisionTrainer:
"""Handle edge cases in mixed precision training."""
def __init__(self, model, optimizer):
self.model = model.cuda()
self.optimizer = optimizer
self.scaler = GradScaler()
self.overflow_count = 0
self.consecutive_overflows = 0
self.max_consecutive_overflows = 10
def train_step(self, inputs, targets):
inputs = inputs.cuda()
targets = targets.cuda()
self.optimizer.zero_grad()
with autocast():
outputs = self.model(inputs)
loss = F.cross_entropy(outputs, targets)
self.scaler.scale(loss).backward()
# Check for overflow before step
self.scaler.unscale_(self.optimizer)
# Gradient norm for monitoring
grad_norm = self._get_grad_norm()
        if not math.isfinite(grad_norm):  # inf or nan gradient norm
self.overflow_count += 1
self.consecutive_overflows += 1
if self.consecutive_overflows >= self.max_consecutive_overflows:
print("Too many consecutive overflows. Consider:")
print("- Reducing learning rate")
print("- Using gradient clipping")
print("- Checking for numerical instability in model")
# Skip this update
self.optimizer.zero_grad()
return None, True
self.consecutive_overflows = 0
# Clip gradients
torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=1.0)
self.scaler.step(self.optimizer)
self.scaler.update()
return loss.item(), False
def _get_grad_norm(self):
total_norm = 0
for p in self.model.parameters():
if p.grad is not None:
param_norm = p.grad.data.norm(2)
total_norm += param_norm.item() ** 2
return total_norm ** 0.5
```
Model-Specific Considerations
```python
import math

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.cuda.amp import autocast
class TransformerMixedPrecision(nn.Module):
"""Transformer with careful mixed precision handling."""
    def __init__(self, vocab_size, d_model, nhead, num_layers):
super().__init__()
self.embedding = nn.Embedding(vocab_size, d_model)
encoder_layer = nn.TransformerEncoderLayer(
d_model, nhead,
dim_feedforward=4*d_model,
batch_first=True
)
self.transformer = nn.TransformerEncoder(encoder_layer, num_layers)
# Layer norms should be careful with precision
self.norm = FP32LayerNorm(d_model)
self.output = nn.Linear(d_model, vocab_size)
def forward(self, x):
with autocast():
x = self.embedding(x)
x = self.transformer(x)
# Force FP32 for final layer norm
x = self.norm(x)
with autocast():
x = self.output(x)
return x
class AttentionWithFP32Softmax(nn.Module):
"""Attention with FP32 softmax for numerical stability."""
def forward(self, query, key, value, mask=None):
d_k = query.size(-1)
with autocast():
scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(d_k)
if mask is not None:
                scores = scores.masked_fill(mask == 0, torch.finfo(scores.dtype).min)  # -1e9 overflows FP16
# Softmax in FP32
attn_weights = F.softmax(scores.float(), dim=-1).to(query.dtype)
with autocast():
output = torch.matmul(attn_weights, value)
return output
```
Benchmarking and Validation
```python
import time

import torch
import torch.nn.functional as F
from torch.cuda.amp import autocast, GradScaler
class MixedPrecisionBenchmark:
"""Compare FP32 vs mixed precision training."""
def __init__(self, model_fn, input_shape, batch_size=32):
self.model_fn = model_fn
self.input_shape = input_shape
self.batch_size = batch_size
def benchmark(self, num_iterations=100):
results = {}
# FP32 baseline
results['fp32'] = self._benchmark_fp32(num_iterations)
# FP16 with autocast
results['fp16_amp'] = self._benchmark_amp(num_iterations, torch.float16)
# BF16 if available
if torch.cuda.is_bf16_supported():
results['bf16_amp'] = self._benchmark_amp(num_iterations, torch.bfloat16)
return results
def _benchmark_fp32(self, num_iterations):
model = self.model_fn().cuda()
optimizer = torch.optim.Adam(model.parameters())
inputs = torch.randn(self.batch_size, *self.input_shape).cuda()
targets = torch.randint(0, 10, (self.batch_size,)).cuda()
# Warmup
for _ in range(10):
self._train_step_fp32(model, inputs, targets, optimizer)
        torch.cuda.synchronize()
        torch.cuda.reset_peak_memory_stats()  # so peak memory reflects this run only
        start = time.time()
for _ in range(num_iterations):
self._train_step_fp32(model, inputs, targets, optimizer)
torch.cuda.synchronize()
elapsed = time.time() - start
memory = torch.cuda.max_memory_allocated() / 1e9
return {
'time': elapsed,
'throughput': num_iterations / elapsed,
'memory_gb': memory
}
def _benchmark_amp(self, num_iterations, dtype):
model = self.model_fn().cuda()
optimizer = torch.optim.Adam(model.parameters())
scaler = GradScaler() if dtype == torch.float16 else None
inputs = torch.randn(self.batch_size, *self.input_shape).cuda()
targets = torch.randint(0, 10, (self.batch_size,)).cuda()
# Warmup
for _ in range(10):
self._train_step_amp(model, inputs, targets, optimizer, scaler, dtype)
torch.cuda.synchronize()
torch.cuda.reset_peak_memory_stats()
start = time.time()
for _ in range(num_iterations):
self._train_step_amp(model, inputs, targets, optimizer, scaler, dtype)
torch.cuda.synchronize()
elapsed = time.time() - start
memory = torch.cuda.max_memory_allocated() / 1e9
return {
'time': elapsed,
'throughput': num_iterations / elapsed,
'memory_gb': memory
        }

    def _train_step_fp32(self, model, inputs, targets, optimizer):
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = F.cross_entropy(outputs, targets)
        loss.backward()
        optimizer.step()

    def _train_step_amp(self, model, inputs, targets, optimizer, scaler, dtype):
        optimizer.zero_grad()
        with autocast(dtype=dtype):
            outputs = model(inputs)
            loss = F.cross_entropy(outputs, targets)
        if scaler is not None:  # FP16 needs scaling; BF16 does not
            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()
        else:
            loss.backward()
            optimizer.step()
def validate_accuracy(self, model_fp32, model_amp, test_loader):
"""Ensure mixed precision doesn't degrade accuracy significantly."""
model_fp32.eval()
model_amp.eval()
fp32_correct = 0
amp_correct = 0
total = 0
with torch.no_grad():
for inputs, targets in test_loader:
inputs = inputs.cuda()
targets = targets.cuda()
# FP32 predictions
outputs_fp32 = model_fp32(inputs)
# AMP predictions
with autocast():
outputs_amp = model_amp(inputs)
                fp32_correct += (outputs_fp32.argmax(1) == targets).sum().item()
                amp_correct += (outputs_amp.argmax(1) == targets).sum().item()
total += len(targets)
return {
'fp32_accuracy': fp32_correct / total,
'amp_accuracy': amp_correct / total,
'difference': abs((fp32_correct - amp_correct) / total)
}
```
Conclusion
Mixed precision training is a powerful technique that can double training speed while halving memory usage, with minimal impact on model accuracy. Modern frameworks make implementation straightforward, but understanding the underlying mechanics helps troubleshoot issues and optimize performance.
Key takeaways:
- Use autocast: Automatic casting handles most operations correctly
- GradScaler prevents underflow: Essential for FP16, not needed for BF16
- Keep some operations in FP32: Loss computation, normalization, softmax
- BF16 is simpler: Same range as FP32, no scaling needed
- Monitor for overflow: Dynamic loss scaling adapts automatically
- Validate accuracy: Ensure mixed precision doesn’t hurt model quality
Whether you’re training on a single GPU or across hundreds, mixed precision training is almost always worth enabling—it’s essentially free performance gains with minimal code changes.