Mixed precision training has become a standard technique for accelerating deep learning. By using lower-precision numerical formats like FP16 or BF16 alongside FP32, we can dramatically speed up training while reducing memory usage. This comprehensive guide explores the principles, implementation, and best practices of mixed precision training.

Understanding Numerical Precision

Floating-Point Formats

Different floating-point formats trade off precision for efficiency:

FP32 (Single Precision):

  • 1 sign bit, 8 exponent bits, 23 mantissa bits
  • Range: ±3.4 × 10^38
  • Standard precision for neural network training

FP16 (Half Precision):

  • 1 sign bit, 5 exponent bits, 10 mantissa bits
  • Range: ±65,504
  • Half the memory of FP32; 2-8x faster on modern GPUs

BF16 (Brain Floating Point):

  • 1 sign bit, 8 exponent bits, 7 mantissa bits
  • Same range as FP32, reduced precision
  • Developed by Google for ML workloads

TF32 (TensorFloat-32):

  • 1 sign bit, 8 exponent bits, 10 mantissa bits
  • NVIDIA format introduced with Ampere GPUs
  • Used internally by tensor cores; not a storage format
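The bit layouts above can be inspected directly with Python's stdlib `struct` module, which supports IEEE binary16 via the `'e'` format. This is a minimal sketch (no GPU or framework needed); the BF16 line relies on the fact that BF16 is simply FP32 with the low 16 mantissa bits dropped:

```python
import struct

# FP16 encoding of 1.0: sign=0, exponent=01111 (bias 15), mantissa=0
bits16 = int.from_bytes(struct.pack('>e', 1.0), 'big')
print(f"{bits16:016b}")  # 0011110000000000

# FP32 encoding of 1.0: sign=0, exponent=01111111 (bias 127), mantissa=0
bits32 = int.from_bytes(struct.pack('>f', 1.0), 'big')
print(f"{bits32:032b}")  # 00111111100000000000000000000000

# BF16 is the top 16 bits of FP32: same 8-bit exponent, 7-bit mantissa
bf16_bits = bits32 >> 16
print(f"{bf16_bits:016b}")  # 0011111110000000
```

Note how FP16 and FP32 use different exponent widths (5 vs 8 bits), while BF16 keeps the FP32 exponent and therefore the FP32 range.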

```python
import torch

# Check available formats
print(f"FP32: {torch.float32}")
print(f"FP16: {torch.float16}")
print(f"BF16: {torch.bfloat16}")

# Memory comparison
tensor_fp32 = torch.randn(1000, 1000, dtype=torch.float32)
tensor_fp16 = torch.randn(1000, 1000, dtype=torch.float16)
tensor_bf16 = torch.randn(1000, 1000, dtype=torch.bfloat16)

print(f"FP32 memory: {tensor_fp32.element_size() * tensor_fp32.nelement() / 1e6:.2f} MB")
print(f"FP16 memory: {tensor_fp16.element_size() * tensor_fp16.nelement() / 1e6:.2f} MB")
print(f"BF16 memory: {tensor_bf16.element_size() * tensor_bf16.nelement() / 1e6:.2f} MB")
```

Why Mixed Precision?

Using FP16 everywhere would cause problems:

  • Underflow: Small gradients become zero
  • Overflow: Large values exceed FP16 range
  • Precision loss: Weight updates may be lost

Mixed precision solves this by:

  • Computing in FP16 for speed
  • Storing master weights in FP32 for precision
  • Using loss scaling to prevent underflow
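All three failure modes are easy to reproduce with IEEE half-precision round-trips. The sketch below uses only the stdlib `struct` module's binary16 (`'e'`) format, so it runs without a GPU; the `to_half` helper is an illustrative stand-in for FP16 storage:

```python
import struct

def to_half(x: float) -> float:
    """Round-trip a Python float through IEEE binary16 (FP16)."""
    return struct.unpack('<e', struct.pack('<e', x))[0]

# Underflow: values below ~6e-8 flush to zero in FP16
print(to_half(1e-8))        # 0.0 -- a gradient this small is lost

# Overflow: FP16 cannot represent values beyond +-65,504
try:
    struct.pack('<e', 70000.0)
except OverflowError:
    print("70000.0 overflows FP16")

# Precision loss: near 1.0, adjacent FP16 values are ~0.001 apart
print(to_half(1.0 + 1e-4))  # 1.0 -- the small update vanished
```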

The Mixed Precision Recipe

Three Key Techniques

  1. FP16 Forward/Backward: Compute activations and gradients in FP16
  2. FP32 Master Weights: Maintain full-precision copy of weights
  3. Loss Scaling: Scale loss to prevent gradient underflow
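Techniques 2 and 3 can be demonstrated numerically before looking at any training code. This sketch simulates FP16 with a stdlib `struct` round-trip helper (an illustrative assumption, not framework code): loss scaling rescues a gradient that would otherwise underflow, and an FP32 master weight retains an update that FP16 accumulation rounds away:

```python
import struct

def to_half(x: float) -> float:
    """Simulate FP16 storage by round-tripping through IEEE binary16."""
    return struct.unpack('<e', struct.pack('<e', x))[0]

# Technique 3 -- loss scaling: a tiny gradient underflows in FP16...
grad = 1e-8
assert to_half(grad) == 0.0
# ...but survives when the loss (hence the gradient) is scaled by 2^16
scaled = to_half(grad * 2**16)        # now representable in FP16
recovered = scaled / 2**16            # unscale in FP32
print(abs(recovered - grad) < 1e-10)  # True

# Technique 2 -- master weights: the same update is lost in FP16 accumulation
lr_update = 1e-4
weight_fp16 = to_half(to_half(1.0) + lr_update)  # rounds back to 1.0
weight_fp32 = 1.0 + lr_update                    # FP32 master keeps it
print(weight_fp16 == 1.0, weight_fp32 > 1.0)     # True True
```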

```python
class ManualMixedPrecision:
    """Understanding the core mixed precision mechanism."""

    def __init__(self, model):
        # FP32 master weights (construct the optimizer over these,
        # not over the FP16 model parameters)
        self.master_weights = {
            name: param.clone().float()
            for name, param in model.named_parameters()
        }
        # FP16 model for computation
        self.model = model.half()
        # Loss scale (starts high, adjusted dynamically)
        self.loss_scale = 2**16
        self.scale_factor = 2
        self.scale_window = 2000
        self.step_count = 0
        self.overflow_count = 0

    def forward_backward(self, inputs, targets, criterion):
        # Forward in FP16
        inputs = inputs.half()
        outputs = self.model(inputs)
        # Compute loss and scale it
        loss = criterion(outputs, targets)
        scaled_loss = loss * self.loss_scale
        # Backward (gradients are in FP16)
        scaled_loss.backward()
        return loss.item()

    def step(self, optimizer):
        # Check for overflow in gradients
        overflow = False
        for param in self.model.parameters():
            if param.grad is not None:
                if torch.isinf(param.grad).any() or torch.isnan(param.grad).any():
                    overflow = True
                    break

        if overflow:
            # Skip update, reduce scale
            self.overflow_count += 1
            self.loss_scale /= self.scale_factor
            self.model.zero_grad()
            return False

        # Unscale gradients and attach them to the master weights
        for name, param in self.model.named_parameters():
            if param.grad is not None:
                grad_fp32 = param.grad.float() / self.loss_scale
                self.master_weights[name].grad = grad_fp32

        # Optimizer step on master weights
        optimizer.step()

        # Copy master weights back to FP16 model
        for name, param in self.model.named_parameters():
            param.data.copy_(self.master_weights[name].half())

        self.model.zero_grad()

        # Periodically try a larger scale
        self.step_count += 1
        if self.step_count % self.scale_window == 0:
            self.loss_scale *= self.scale_factor
        return True
```

PyTorch Automatic Mixed Precision

Using torch.cuda.amp

```python
import torch
from torch.cuda.amp import autocast, GradScaler

def train_with_amp(model, dataloader, optimizer, criterion, epochs):
    model = model.cuda()
    # Create gradient scaler
    scaler = GradScaler()

    for epoch in range(epochs):
        for inputs, targets in dataloader:
            inputs = inputs.cuda()
            targets = targets.cuda()
            optimizer.zero_grad()

            # Forward pass with automatic casting
            with autocast():
                outputs = model(inputs)
                loss = criterion(outputs, targets)

            # Backward pass with gradient scaling
            scaler.scale(loss).backward()
            # Optimizer step with unscaling
            scaler.step(optimizer)
            # Update scaler for next iteration
            scaler.update()

# More detailed control
class AMPTrainer:
    def __init__(self, model, optimizer, criterion):
        self.model = model.cuda()
        self.optimizer = optimizer
        self.criterion = criterion
        self.scaler = GradScaler(
            init_scale=2**16,
            growth_factor=2.0,
            backoff_factor=0.5,
            growth_interval=2000,
            enabled=True
        )

    def train_step(self, inputs, targets):
        inputs = inputs.cuda()
        targets = targets.cuda()
        self.optimizer.zero_grad()

        # Autocast context for forward pass
        with autocast(dtype=torch.float16):
            outputs = self.model(inputs)
            loss = self.criterion(outputs, targets)

        # Scale loss and backward
        self.scaler.scale(loss).backward()

        # Optional: gradient clipping requires unscaling first
        self.scaler.unscale_(self.optimizer)
        torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=1.0)

        # scaler.step checks for inf/nan before stepping
        self.scaler.step(self.optimizer)
        self.scaler.update()
        return loss.item()

    @torch.no_grad()
    def validate(self, dataloader):
        self.model.eval()
        total_loss = 0
        for inputs, targets in dataloader:
            inputs = inputs.cuda()
            targets = targets.cuda()
            with autocast(dtype=torch.float16):
                outputs = self.model(inputs)
                loss = self.criterion(outputs, targets)
            total_loss += loss.item()
        self.model.train()
        return total_loss / len(dataloader)
```

BFloat16 Training

```python
import torch.nn.functional as F

def train_with_bf16(model, dataloader, optimizer, criterion):
    """BFloat16 training - simpler than FP16, no loss scaling needed."""
    # Check BF16 support
    if not torch.cuda.is_bf16_supported():
        raise RuntimeError("BF16 not supported on this GPU")

    model = model.cuda()
    for inputs, targets in dataloader:
        inputs = inputs.cuda()
        targets = targets.cuda()
        optimizer.zero_grad()

        # Autocast with bfloat16
        with autocast(dtype=torch.bfloat16):
            outputs = model(inputs)
            loss = criterion(outputs, targets)

        # No scaling needed for BF16
        loss.backward()
        optimizer.step()

class BF16Trainer:
    """Pure BF16 training without GradScaler."""

    def __init__(self, model, optimizer):
        self.model = model.cuda().bfloat16()  # Convert model to BF16
        self.optimizer = optimizer

    def train_step(self, inputs, targets):
        inputs = inputs.cuda().bfloat16()
        targets = targets.cuda()
        self.optimizer.zero_grad()
        outputs = self.model(inputs)
        loss = F.cross_entropy(outputs.float(), targets)  # Loss in FP32
        loss.backward()
        self.optimizer.step()
        return loss.item()
```

Operations and Precision Selection

What Should Be FP16 vs FP32?

```python
import torch.nn as nn
import torch.nn.functional as F

# Operations safe for FP16
safe_for_fp16 = [
    'linear',      # Matrix multiplications
    'conv2d',      # Convolutions
    'batch_norm',  # Batch normalization
    'relu',        # ReLU activation
    'max_pool',    # Pooling operations
    'dropout',     # Dropout
]

# Operations that should stay FP32
keep_fp32 = [
    'softmax',           # Numerical stability
    'cross_entropy',     # Loss computation
    'layer_norm',        # Normalization
    'batch_norm_stats',  # Running statistics
    'optimizer_step',    # Weight updates
]

# Custom autocast behavior
class CustomMixedPrecision(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(3, 64, 3)
        self.norm = nn.LayerNorm(64)  # normalizes over the channel dim

    def forward(self, x):
        # Autocast handles most operations
        with autocast():
            x = self.conv(x)  # FP16
            x = F.relu(x)     # FP16

        # Force specific operations to FP32
        with autocast(enabled=False):
            # LayerNorm(64) expects channels last, so permute NCHW -> NHWC
            x = x.permute(0, 2, 3, 1)
            x = self.norm(x.float())  # FP32
            x = x.permute(0, 3, 1, 2)
        return x
```

Custom Autocast Ops

```python
from torch.cuda.amp import custom_fwd, custom_bwd

class FP32LayerNorm(nn.LayerNorm):
    """Layer norm that always runs in FP32."""
    def forward(self, x):
        input_dtype = x.dtype
        x = x.float()
        x = super().forward(x)
        return x.to(input_dtype)

class FP32Softmax(nn.Module):
    """Softmax with FP32 computation."""
    def forward(self, x, dim=-1):
        return F.softmax(x.float(), dim=dim).to(x.dtype)

# custom_fwd/custom_bwd decorate the forward/backward of a custom
# autograd Function to pin its precision under autocast
class FP32Normalize(torch.autograd.Function):
    """Custom softmax-style op that always runs in FP32."""

    @staticmethod
    @custom_fwd(cast_inputs=torch.float32)
    def forward(ctx, x):
        s = torch.softmax(x, dim=-1)
        ctx.save_for_backward(s)
        return s

    @staticmethod
    @custom_bwd
    def backward(ctx, grad_output):
        s, = ctx.saved_tensors
        return s * (grad_output - (grad_output * s).sum(dim=-1, keepdim=True))

# cast_inputs=torch.float16 pins an op to FP16 the same way
```

Loss Scaling Strategies

Static Loss Scaling

```python
class StaticLossScaling:
    """Fixed loss scale (simpler but less flexible)."""

    def __init__(self, model, optimizer, scale=1024.0):
        self.model = model.half().cuda()
        self.optimizer = optimizer
        self.scale = scale

    def train_step(self, inputs, targets):
        inputs = inputs.half().cuda()
        targets = targets.cuda()
        self.optimizer.zero_grad()
        outputs = self.model(inputs)
        loss = F.cross_entropy(outputs, targets)

        # Scale and backward
        (loss * self.scale).backward()

        # Unscale gradients
        for param in self.model.parameters():
            if param.grad is not None:
                param.grad.data /= self.scale

        self.optimizer.step()
        return loss.item()
```

Dynamic Loss Scaling

```python
class DynamicLossScaler:
    """Automatically adjust loss scale based on overflow detection."""

    def __init__(self, init_scale=2**16, scale_factor=2.0,
                 scale_window=2000, min_scale=1.0, max_scale=2**24):
        self.scale = init_scale
        self.scale_factor = scale_factor
        self.scale_window = scale_window
        self.min_scale = min_scale
        self.max_scale = max_scale
        self.steps_since_overflow = 0
        self.overflow_history = []

    def has_overflow(self, params):
        """Check if any gradients contain inf or nan."""
        for param in params:
            if param.grad is not None:
                grad_data = param.grad.data
                if torch.isinf(grad_data).any() or torch.isnan(grad_data).any():
                    return True
        return False

    def update_scale(self, overflow):
        """Update scale based on overflow status."""
        if overflow:
            # Reduce scale on overflow
            self.scale = max(self.min_scale, self.scale / self.scale_factor)
            self.steps_since_overflow = 0
            self.overflow_history.append(True)
        else:
            self.steps_since_overflow += 1
            self.overflow_history.append(False)
            # Increase scale after enough successful steps
            if self.steps_since_overflow >= self.scale_window:
                self.scale = min(self.max_scale, self.scale * self.scale_factor)
                self.steps_since_overflow = 0

    def state_dict(self):
        return {
            'scale': self.scale,
            'steps_since_overflow': self.steps_since_overflow
        }

    def load_state_dict(self, state_dict):
        self.scale = state_dict['scale']
        self.steps_since_overflow = state_dict['steps_since_overflow']
```

Mixed Precision with Distributed Training

```python
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP

def distributed_mixed_precision_training(rank, world_size):
    # Setup distributed
    dist.init_process_group('nccl', rank=rank, world_size=world_size)
    torch.cuda.set_device(rank)

    # Create model and wrap with DDP
    model = MyModel().cuda()
    model = DDP(model, device_ids=[rank])
    optimizer = torch.optim.Adam(model.parameters())
    scaler = GradScaler()

    for epoch in range(epochs):
        for inputs, targets in dataloader:
            inputs = inputs.cuda()
            targets = targets.cuda()
            optimizer.zero_grad()

            with autocast():
                outputs = model(inputs)
                loss = F.cross_entropy(outputs, targets)

            scaler.scale(loss).backward()

            # Gradient clipping with unscaling
            scaler.unscale_(optimizer)
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)

            scaler.step(optimizer)
            scaler.update()

# FSDP with mixed precision
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp import MixedPrecision, ShardingStrategy

def fsdp_mixed_precision():
    mp_policy = MixedPrecision(
        param_dtype=torch.float16,   # Parameters in FP16
        reduce_dtype=torch.float16,  # Gradient reduction in FP16
        buffer_dtype=torch.float16,  # Buffers in FP16
        cast_forward_inputs=True
    )

    model = FSDP(
        MyModel().cuda(),
        mixed_precision=mp_policy,
        sharding_strategy=ShardingStrategy.FULL_SHARD
    )

    # FSDP casts internally; with FP16 params, FSDP's ShardedGradScaler
    # is still recommended to guard against gradient underflow
    for inputs, targets in dataloader:
        with autocast():
            outputs = model(inputs)
            loss = F.cross_entropy(outputs, targets)
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()
```

Handling Edge Cases

Dealing with Overflow

```python
import math

class RobustMixedPrecisionTrainer:
    """Handle edge cases in mixed precision training."""

    def __init__(self, model, optimizer):
        self.model = model.cuda()
        self.optimizer = optimizer
        self.scaler = GradScaler()
        self.overflow_count = 0
        self.consecutive_overflows = 0
        self.max_consecutive_overflows = 10

    def train_step(self, inputs, targets):
        inputs = inputs.cuda()
        targets = targets.cuda()
        self.optimizer.zero_grad()

        with autocast():
            outputs = self.model(inputs)
            loss = F.cross_entropy(outputs, targets)

        self.scaler.scale(loss).backward()

        # Unscale so gradients can be inspected and clipped
        self.scaler.unscale_(self.optimizer)

        # Gradient norm for monitoring
        grad_norm = self._get_grad_norm()
        if math.isinf(grad_norm) or math.isnan(grad_norm):
            self.overflow_count += 1
            self.consecutive_overflows += 1
            if self.consecutive_overflows >= self.max_consecutive_overflows:
                print("Too many consecutive overflows. Consider:")
                print("- Reducing learning rate")
                print("- Using gradient clipping")
                print("- Checking for numerical instability in model")
            # Skip this update; scaler.update() lowers the scale
            self.optimizer.zero_grad()
            self.scaler.update()
            return None, True

        self.consecutive_overflows = 0

        # Clip gradients
        torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=1.0)

        self.scaler.step(self.optimizer)
        self.scaler.update()
        return loss.item(), False

    def _get_grad_norm(self):
        total_norm = 0.0
        for p in self.model.parameters():
            if p.grad is not None:
                param_norm = p.grad.data.norm(2)
                total_norm += param_norm.item() ** 2
        return total_norm ** 0.5
```

Model-Specific Considerations

```python
import math

class TransformerMixedPrecision(nn.Module):
    """Transformer with careful mixed precision handling."""

    def __init__(self, vocab_size, d_model, nhead, num_layers):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, d_model)
        encoder_layer = nn.TransformerEncoderLayer(
            d_model, nhead,
            dim_feedforward=4 * d_model,
            batch_first=True
        )
        self.transformer = nn.TransformerEncoder(encoder_layer, num_layers)
        # Layer norms are sensitive to reduced precision
        self.norm = FP32LayerNorm(d_model)
        self.output = nn.Linear(d_model, vocab_size)

    def forward(self, x):
        with autocast():
            x = self.embedding(x)
            x = self.transformer(x)
        # Final layer norm computes in FP32 internally
        x = self.norm(x)
        with autocast():
            x = self.output(x)
        return x

class AttentionWithFP32Softmax(nn.Module):
    """Attention with FP32 softmax for numerical stability."""

    def forward(self, query, key, value, mask=None):
        d_k = query.size(-1)
        with autocast():
            scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(d_k)
        if mask is not None:
            scores = scores.masked_fill(mask == 0, -1e9)
        # Softmax in FP32
        attn_weights = F.softmax(scores.float(), dim=-1).to(query.dtype)
        with autocast():
            output = torch.matmul(attn_weights, value)
        return output
```

Benchmarking and Validation

```python
import time

class MixedPrecisionBenchmark:
    """Compare FP32 vs mixed precision training."""

    def __init__(self, model_fn, input_shape, batch_size=32):
        self.model_fn = model_fn
        self.input_shape = input_shape
        self.batch_size = batch_size

    def benchmark(self, num_iterations=100):
        results = {}
        # FP32 baseline
        results['fp32'] = self._benchmark_fp32(num_iterations)
        # FP16 with autocast
        results['fp16_amp'] = self._benchmark_amp(num_iterations, torch.float16)
        # BF16 if available
        if torch.cuda.is_bf16_supported():
            results['bf16_amp'] = self._benchmark_amp(num_iterations, torch.bfloat16)
        return results

    def _train_step_fp32(self, model, inputs, targets, optimizer):
        optimizer.zero_grad()
        loss = F.cross_entropy(model(inputs), targets)
        loss.backward()
        optimizer.step()

    def _train_step_amp(self, model, inputs, targets, optimizer, scaler, dtype):
        optimizer.zero_grad()
        with autocast(dtype=dtype):
            loss = F.cross_entropy(model(inputs), targets)
        if scaler is not None:  # FP16 path
            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()
        else:                   # BF16 path, no scaling needed
            loss.backward()
            optimizer.step()

    def _benchmark_fp32(self, num_iterations):
        model = self.model_fn().cuda()
        optimizer = torch.optim.Adam(model.parameters())
        inputs = torch.randn(self.batch_size, *self.input_shape).cuda()
        targets = torch.randint(0, 10, (self.batch_size,)).cuda()

        # Warmup
        for _ in range(10):
            self._train_step_fp32(model, inputs, targets, optimizer)

        torch.cuda.synchronize()
        torch.cuda.reset_peak_memory_stats()
        start = time.time()
        for _ in range(num_iterations):
            self._train_step_fp32(model, inputs, targets, optimizer)
        torch.cuda.synchronize()
        elapsed = time.time() - start
        memory = torch.cuda.max_memory_allocated() / 1e9

        return {
            'time': elapsed,
            'throughput': num_iterations / elapsed,
            'memory_gb': memory
        }

    def _benchmark_amp(self, num_iterations, dtype):
        model = self.model_fn().cuda()
        optimizer = torch.optim.Adam(model.parameters())
        scaler = GradScaler() if dtype == torch.float16 else None
        inputs = torch.randn(self.batch_size, *self.input_shape).cuda()
        targets = torch.randint(0, 10, (self.batch_size,)).cuda()

        # Warmup
        for _ in range(10):
            self._train_step_amp(model, inputs, targets, optimizer, scaler, dtype)

        torch.cuda.synchronize()
        torch.cuda.reset_peak_memory_stats()
        start = time.time()
        for _ in range(num_iterations):
            self._train_step_amp(model, inputs, targets, optimizer, scaler, dtype)
        torch.cuda.synchronize()
        elapsed = time.time() - start
        memory = torch.cuda.max_memory_allocated() / 1e9

        return {
            'time': elapsed,
            'throughput': num_iterations / elapsed,
            'memory_gb': memory
        }

    def validate_accuracy(self, model_fp32, model_amp, test_loader):
        """Ensure mixed precision doesn't degrade accuracy significantly."""
        model_fp32.eval()
        model_amp.eval()
        fp32_correct = 0
        amp_correct = 0
        total = 0

        with torch.no_grad():
            for inputs, targets in test_loader:
                inputs = inputs.cuda()
                targets = targets.cuda()

                # FP32 predictions
                outputs_fp32 = model_fp32(inputs)
                # AMP predictions
                with autocast():
                    outputs_amp = model_amp(inputs)

                fp32_correct += (outputs_fp32.argmax(1) == targets).sum().item()
                amp_correct += (outputs_amp.argmax(1) == targets).sum().item()
                total += len(targets)

        return {
            'fp32_accuracy': fp32_correct / total,
            'amp_accuracy': amp_correct / total,
            'difference': abs(fp32_correct - amp_correct) / total
        }
```

Conclusion

Mixed precision training is a powerful technique that can double training speed while halving memory usage, with minimal impact on model accuracy. Modern frameworks make implementation straightforward, but understanding the underlying mechanics helps troubleshoot issues and optimize performance.

Key takeaways:

  1. Use autocast: Automatic casting handles most operations correctly
  2. GradScaler prevents underflow: Essential for FP16, not needed for BF16
  3. Keep some operations in FP32: Loss computation, normalization, softmax
  4. BF16 is simpler: Same range as FP32, no scaling needed
  5. Monitor for overflow: Dynamic loss scaling adapts automatically
  6. Validate accuracy: Ensure mixed precision doesn’t hurt model quality

Whether you’re training on a single GPU or across hundreds, mixed precision training is almost always worth enabling—it’s essentially free performance gains with minimal code changes.
