Quantization has become one of the most impactful techniques for deploying AI models efficiently. By reducing the precision of weights and activations from 32-bit floating point to lower bit-widths, quantization dramatically shrinks model size, cuts memory bandwidth requirements, and enables faster computation. This comprehensive guide explores the principles, methods, and practical applications of neural network quantization.

Understanding Quantization

What Is Quantization?

Quantization maps high-precision floating-point values to lower-precision representations:

  • FP32 → FP16: 2x compression, minimal accuracy loss
  • FP32 → INT8: 4x compression, usually <1% accuracy drop
  • FP32 → INT4: 8x compression, requires careful handling
  • FP32 → Binary: 32x compression, significant accuracy trade-off

```python
# Imports assumed throughout this post
import copy

import torch
import torch.nn as nn
import torch.nn.functional as F


# Basic quantization concept
def quantize_tensor(tensor, num_bits=8):
    """Quantize a floating-point tensor to fixed-point."""
    # Find the dynamic range
    min_val = tensor.min()
    max_val = tensor.max()

    # Compute scale and zero point (asymmetric, unsigned range)
    qmin = 0
    qmax = 2**num_bits - 1
    scale = (max_val - min_val) / (qmax - qmin)
    zero_point = int(round(qmin - (min_val / scale).item()))

    # Quantize
    q_tensor = torch.round(tensor / scale + zero_point)
    q_tensor = torch.clamp(q_tensor, qmin, qmax).to(torch.uint8)
    return q_tensor, scale, zero_point


def dequantize_tensor(q_tensor, scale, zero_point):
    """Convert a quantized tensor back to floating-point."""
    return scale * (q_tensor.float() - zero_point)
```

Why Quantization Works

Neural networks are inherently redundant and robust to noise:

  1. Weight distribution: Most weights cluster around zero, where a uniform grid has plenty of resolution
  2. Noise tolerance: Networks trained with noise (dropout, augmentation) tolerate the small perturbations rounding introduces
  3. Discretization: Continuous functions can be approximated well by discrete ones
  4. Statistical redundancy: Full precision isn't needed for most computations

Quantization Fundamentals

Scale (S): Maps between quantized and real values

Zero Point (Z): Ensures zero is exactly representable

Quantization formula: q = round(x/S + Z)

Dequantization: x ≈ S × (q - Z)

```python
class Quantizer:
    def __init__(self, num_bits=8, symmetric=False):
        self.num_bits = num_bits
        self.symmetric = symmetric
        if symmetric:
            self.qmin = -(2**(num_bits - 1))
            self.qmax = 2**(num_bits - 1) - 1
        else:
            self.qmin = 0
            self.qmax = 2**num_bits - 1

    def compute_scale_zp(self, tensor):
        min_val = tensor.min().item()
        max_val = tensor.max().item()
        if self.symmetric:
            # Symmetric: zero point fixed at 0, scale covers the max magnitude
            max_abs = max(abs(min_val), abs(max_val))
            scale = max_abs / self.qmax
            zero_point = 0
        else:
            # Asymmetric: shift the grid so min_val maps to qmin
            scale = (max_val - min_val) / (self.qmax - self.qmin)
            zero_point = self.qmin - round(min_val / scale)
            zero_point = max(self.qmin, min(self.qmax, zero_point))
        return scale, int(zero_point)

    def quantize(self, tensor, scale, zero_point):
        q = torch.round(tensor / scale + zero_point)
        # Signed storage for the symmetric range, unsigned for the asymmetric one
        dtype = torch.int8 if self.symmetric else torch.uint8
        return torch.clamp(q, self.qmin, self.qmax).to(dtype)

    def dequantize(self, q_tensor, scale, zero_point):
        return scale * (q_tensor.float() - zero_point)
```
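
A quick round trip makes the error behavior concrete: after quantize/dequantize, the reconstruction error is bounded by half a quantization step. A minimal sketch using the class above:

```python
quantizer = Quantizer(num_bits=8, symmetric=False)
x = torch.randn(4, 4)

scale, zp = quantizer.compute_scale_zp(x)
q = quantizer.quantize(x, scale, zp)
x_hat = quantizer.dequantize(q, scale, zp)

# Max reconstruction error should not exceed half the step size
print((x - x_hat).abs().max().item(), scale / 2)
```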

Types of Quantization

Post-Training Quantization (PTQ)

Quantize a pre-trained model without retraining:

```python
class PostTrainingQuantizer:
    def __init__(self, model, calibration_data):
        self.model = model
        self.calibration_data = calibration_data
        self.activation_ranges = {}

    def calibrate(self, num_batches=100):
        """Collect activation statistics."""
        self.model.eval()
        hooks = []

        def hook_fn(name):
            def hook(module, input, output):
                if name not in self.activation_ranges:
                    self.activation_ranges[name] = {
                        'min': float('inf'),
                        'max': float('-inf'),
                    }
                self.activation_ranges[name]['min'] = min(
                    self.activation_ranges[name]['min'],
                    output.min().item()
                )
                self.activation_ranges[name]['max'] = max(
                    self.activation_ranges[name]['max'],
                    output.max().item()
                )
            return hook

        # Register hooks on the layers we plan to quantize
        for name, module in self.model.named_modules():
            if isinstance(module, (nn.Conv2d, nn.Linear)):
                hooks.append(module.register_forward_hook(hook_fn(name)))

        # Run calibration batches through the model
        with torch.no_grad():
            for i, (images, _) in enumerate(self.calibration_data):
                if i >= num_batches:
                    break
                self.model(images)

        # Remove hooks
        for hook in hooks:
            hook.remove()

    def quantize_model(self):
        """Apply quantization using the collected statistics."""
        quantized_model = copy.deepcopy(self.model)
        for name, module in quantized_model.named_modules():
            if isinstance(module, nn.Conv2d):
                # Quantize weights (helper methods sketched below)
                w_scale, w_zp = self.compute_weight_params(module.weight)
                module.weight.data = self.quantize_weights(
                    module.weight.data, w_scale, w_zp
                )
                # Store activation quantization parameters
                if name in self.activation_ranges:
                    a_scale, a_zp = self.compute_activation_params(
                        self.activation_ranges[name]
                    )
                    module.register_buffer('act_scale', torch.tensor(a_scale))
                    module.register_buffer('act_zp', torch.tensor(a_zp))
        return quantized_model
```
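
The class above calls three helpers it never defines. One plausible way to fill them in is a symmetric scheme for weights and an asymmetric one for activations; the method names come from the calls above, but these implementations are illustrative assumptions:

```python
# Hypothetical completions for the helpers referenced in quantize_model()
def compute_weight_params(self, weight, num_bits=8):
    # Symmetric per-tensor weights: zero point fixed at 0
    scale = weight.abs().max().item() / (2**(num_bits - 1) - 1)
    return scale, 0

def quantize_weights(self, weight, scale, zero_point, num_bits=8):
    qmax = 2**(num_bits - 1) - 1
    q = torch.round(weight / scale + zero_point).clamp(-qmax - 1, qmax)
    # Store the dequantized (fake-quantized) values back in FP32
    return (q - zero_point) * scale

def compute_activation_params(self, ranges, num_bits=8):
    # Asymmetric activations using the calibrated min/max
    scale = (ranges['max'] - ranges['min']) / (2**num_bits - 1)
    zero_point = int(round(-ranges['min'] / scale)) if scale > 0 else 0
    return scale, zero_point

PostTrainingQuantizer.compute_weight_params = compute_weight_params
PostTrainingQuantizer.quantize_weights = quantize_weights
PostTrainingQuantizer.compute_activation_params = compute_activation_params
```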

Quantization-Aware Training (QAT)

Simulate quantization during training for better accuracy:

```python
class QuantizationAwareModule(nn.Module):
    """Wrapper for QAT with fake quantization."""

    def __init__(self, module, num_bits=8):
        super().__init__()
        self.module = module
        self.num_bits = num_bits
        # Quantization parameters (buffers, updated from running statistics)
        self.register_buffer('weight_scale', torch.ones(1))
        self.register_buffer('weight_zp', torch.zeros(1))
        self.register_buffer('act_scale', torch.ones(1))
        self.register_buffer('act_zp', torch.zeros(1))
        # EMA for activation statistics
        self.register_buffer('running_min', torch.zeros(1))
        self.register_buffer('running_max', torch.ones(1))
        self.momentum = 0.1

    def fake_quantize(self, x, scale, zero_point, signed=False):
        """Quantize and immediately dequantize (differentiable)."""
        if signed:
            qmin = -(2**(self.num_bits - 1))
            qmax = 2**(self.num_bits - 1) - 1
        else:
            qmin = 0
            qmax = 2**self.num_bits - 1
        # Forward: quantize-dequantize
        x_q = torch.clamp(torch.round(x / scale + zero_point), qmin, qmax)
        x_dq = scale * (x_q - zero_point)
        # Straight-through estimator: gradients pass through as identity
        return x + (x_dq - x).detach()

    def forward(self, x):
        # Update activation statistics with an exponential moving average
        if self.training:
            self.running_min = (1 - self.momentum) * self.running_min + \
                               self.momentum * x.min()
            self.running_max = (1 - self.momentum) * self.running_max + \
                               self.momentum * x.max()
            self.act_scale = (self.running_max - self.running_min) / 255
            self.act_zp = -self.running_min / self.act_scale

        # Fake quantize weights (symmetric signed range, so negatives survive)
        w_scale = self.module.weight.abs().max() / 127
        w_fake = self.fake_quantize(self.module.weight, w_scale, 0, signed=True)

        # Forward with fake-quantized weights
        if isinstance(self.module, nn.Conv2d):
            out = F.conv2d(x, w_fake, self.module.bias,
                           self.module.stride, self.module.padding)
        elif isinstance(self.module, nn.Linear):
            out = F.linear(x, w_fake, self.module.bias)

        # Fake quantize activations
        out = self.fake_quantize(out, self.act_scale, self.act_zp)
        return out


class QATModel(nn.Module):
    def __init__(self, model):
        super().__init__()
        self.model = self._wrap_modules(model)

    def _wrap_modules(self, model):
        for name, child in model.named_children():
            if isinstance(child, (nn.Conv2d, nn.Linear)):
                setattr(model, name, QuantizationAwareModule(child))
            else:
                self._wrap_modules(child)
        return model

    def forward(self, x):
        return self.model(x)
```
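
Training then proceeds exactly as usual; the fake-quantization noise is part of the forward pass, so the network learns weights that survive rounding. A minimal sketch with a toy network and synthetic data (both placeholders for a real setup):

```python
# Toy stand-ins for a real model and dataset
net = nn.Sequential(nn.Conv2d(3, 8, 3, padding=1), nn.Flatten(),
                    nn.Linear(8 * 8 * 8, 10))
model = QATModel(net)
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)
criterion = nn.CrossEntropyLoss()

model.train()
for _ in range(10):
    images = torch.randn(4, 3, 8, 8)
    labels = torch.randint(0, 10, (4,))
    optimizer.zero_grad()
    loss = criterion(model(images), labels)
    loss.backward()  # the STE lets gradients pass through the rounding op
    optimizer.step()
```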

Dynamic Quantization

Quantize weights statically but activations dynamically:

```python
import torch.quantization


def dynamic_quantize_model(model):
    """Apply dynamic quantization to linear layers."""
    quantized_model = torch.quantization.quantize_dynamic(
        model,
        {nn.Linear},  # layers to quantize
        dtype=torch.qint8
    )
    return quantized_model


# More control with manual dynamic quantization
class DynamicQuantizedLinear(nn.Module):
    def __init__(self, linear_layer):
        super().__init__()
        # Quantize weights statically (once, at conversion time)
        self.weight_quantized = torch.quantize_per_tensor(
            linear_layer.weight.detach(),
            scale=(linear_layer.weight.abs().max() / 127).item(),
            zero_point=0,
            dtype=torch.qint8
        )
        self.bias = linear_layer.bias

    def forward(self, x):
        # Quantize the input dynamically (per call, at run time)
        x_scale = (x.abs().max() / 127).item()
        x_quantized = torch.quantize_per_tensor(
            x, scale=x_scale, zero_point=0, dtype=torch.qint8
        )
        # Dequantize for computation (or use quantized ops directly)
        x_dequant = x_quantized.dequantize()
        weight_dequant = self.weight_quantized.dequantize()
        return F.linear(x_dequant, weight_dequant, self.bias)
```
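
The built-in path is a one-liner in practice. A quick sketch showing that the quantized model keeps the same interface:

```python
model = nn.Sequential(nn.Linear(512, 512), nn.ReLU(), nn.Linear(512, 10))
q_model = dynamic_quantize_model(model)

x = torch.randn(1, 512)
print(q_model(x).shape)  # same interface, int8 weights under the hood
```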

Quantization Granularity

Per-Tensor Quantization

Single scale/zero-point for entire tensor:

```python
def per_tensor_quantize(tensor, num_bits=8):
    min_val = tensor.min()
    max_val = tensor.max()
    scale = (max_val - min_val) / (2**num_bits - 1)
    zero_point = -min_val / scale
    q_tensor = torch.round(tensor / scale + zero_point)
    return q_tensor.clamp(0, 2**num_bits - 1)
```

Per-Channel Quantization

Different scale/zero-point for each output channel:

```python
def per_channel_quantize(weight, num_bits=8, axis=0):
    """Per-channel quantization for conv/linear weights."""
    num_channels = weight.shape[axis]
    scales = []
    zero_points = []
    # Compute one scale/zero-point pair per channel
    for c in range(num_channels):
        channel_weights = weight.select(axis, c)
        min_val = channel_weights.min()
        max_val = channel_weights.max()
        scale = (max_val - min_val) / (2**num_bits - 1)
        zero_point = -min_val / scale
        scales.append(scale)
        zero_points.append(zero_point)

    # Quantize each channel with its own parameters
    q_weight = torch.zeros_like(weight)
    for c in range(num_channels):
        channel = weight.select(axis, c)
        q_weight.select(axis, c).copy_(
            torch.round(channel / scales[c] + zero_points[c]).clamp(0, 2**num_bits - 1)
        )
    return q_weight, scales, zero_points
```
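
The per-channel loop above is easy to read but slow for large layers. The same computation vectorizes cleanly with `amin`/`amax`; a sketch assuming dim 0 is the output-channel axis:

```python
def per_channel_quantize_vectorized(weight, num_bits=8):
    # Reduce over every dimension except the output-channel axis (dim 0)
    flat = weight.reshape(weight.shape[0], -1)
    min_val = flat.amin(dim=1, keepdim=True)
    max_val = flat.amax(dim=1, keepdim=True)
    scale = ((max_val - min_val) / (2**num_bits - 1)).clamp(min=1e-8)
    zero_point = -min_val / scale
    q = torch.round(flat / scale + zero_point).clamp(0, 2**num_bits - 1)
    return q.view_as(weight), scale.squeeze(1), zero_point.squeeze(1)
```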

Per-Group Quantization

For very low bit-widths (e.g., INT4):

```python
def per_group_quantize(tensor, group_size=128, num_bits=4):
    """Quantize in groups for better precision at low bit-widths."""
    original_shape = tensor.shape  # kept so the caller can restore the layout
    tensor_flat = tensor.view(-1)

    # Pad to a multiple of group_size
    padded_size = ((len(tensor_flat) + group_size - 1) // group_size) * group_size
    tensor_padded = torch.zeros(padded_size, dtype=tensor.dtype, device=tensor.device)
    tensor_padded[:len(tensor_flat)] = tensor_flat

    # Reshape into groups
    tensor_grouped = tensor_padded.view(-1, group_size)

    # Quantize each group with its own scale and zero point
    scales = []
    zero_points = []
    q_groups = []
    for group in tensor_grouped:
        min_val = group.min()
        max_val = group.max()
        # Clamp avoids division by zero on constant groups
        scale = ((max_val - min_val) / (2**num_bits - 1)).clamp(min=1e-8)
        zp = -min_val / scale
        scales.append(scale)
        zero_points.append(zp)
        q_group = torch.round(group / scale + zp).clamp(0, 2**num_bits - 1)
        q_groups.append(q_group)
    return torch.stack(q_groups), scales, zero_points
```

Advanced Quantization Methods

Mixed-Precision Quantization

Different layers get different bit-widths:

```python
class MixedPrecisionQuantizer:
    def __init__(self, model, sensitivity_analysis):
        self.model = model
        self.sensitivities = sensitivity_analysis  # layer name -> accuracy drop

    def determine_bit_widths(self, target_compression):
        """Assign bit-widths based on layer sensitivity."""
        # Sort layers from most to least sensitive
        sorted_layers = sorted(
            self.sensitivities.items(),
            key=lambda x: x[1],
            reverse=True
        )
        bit_widths = {}
        # Highly sensitive layers get higher precision
        for i, (layer_name, sensitivity) in enumerate(sorted_layers):
            if i < len(sorted_layers) * 0.2:    # top 20% most sensitive
                bit_widths[layer_name] = 8
            elif i < len(sorted_layers) * 0.5:  # middle
                bit_widths[layer_name] = 6
            else:                               # least sensitive
                bit_widths[layer_name] = 4
        return bit_widths

    def quantize_mixed(self, bit_widths):
        for name, module in self.model.named_modules():
            if name in bit_widths:
                bits = bit_widths[name]
                self._quantize_module(module, bits)  # per-module quantizer, not shown


class HardwareAwareMixedPrecision:
    """Select bit-widths considering hardware constraints."""

    def __init__(self, model, hardware_config):
        self.model = model
        self.hw_config = hardware_config
        # e.g. {'supported_bits': [4, 8, 16], 'memory_limit': 100e6}

    def optimize_bit_widths(self, val_loader, accuracy_target=0.99):
        """Find bit-widths meeting accuracy and hardware constraints."""
        import optuna

        def objective(trial):
            bit_widths = {}
            for name, module in self.model.named_modules():
                if isinstance(module, (nn.Conv2d, nn.Linear)):
                    bits = trial.suggest_categorical(
                        name,
                        self.hw_config['supported_bits']
                    )
                    bit_widths[name] = bits
            # Quantize and evaluate (helper methods not shown)
            q_model = self.quantize_with_bits(bit_widths)
            accuracy = self.evaluate(q_model, val_loader)
            if accuracy < accuracy_target:
                return float('inf')
            return self.compute_model_size(bit_widths)

        study = optuna.create_study(direction='minimize')
        study.optimize(objective, n_trials=100)
        return study.best_params
```
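
The `sensitivity_analysis` dictionary has to come from somewhere. A common recipe is to quantize one layer at a time and record the accuracy drop against the FP32 baseline; a sketch in which `evaluate` and `quantize_layer` are assumed helpers supplied by the caller:

```python
def measure_layer_sensitivity(model, val_loader, evaluate, quantize_layer,
                              num_bits=4):
    """Return {layer_name: accuracy_drop} by quantizing layers one at a time."""
    baseline = evaluate(model, val_loader)
    sensitivities = {}
    for name, module in model.named_modules():
        if not isinstance(module, (nn.Conv2d, nn.Linear)):
            continue
        trial_model = copy.deepcopy(model)
        # Quantize only this layer; everything else stays FP32
        quantize_layer(dict(trial_model.named_modules())[name], num_bits)
        sensitivities[name] = baseline - evaluate(trial_model, val_loader)
    return sensitivities
```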

Learned Step Size Quantization (LSQ)

Learn quantization parameters during training:

```python
class LSQ(nn.Module):
    """Learned Step Size Quantization."""

    def __init__(self, num_bits=8, symmetric=True, per_channel=False):
        super().__init__()
        self.num_bits = num_bits
        self.symmetric = symmetric
        self.per_channel = per_channel
        # Learnable scale (step size) parameter
        self.scale = nn.Parameter(torch.ones(1))
        if symmetric:
            self.qmin = -(2**(num_bits - 1))
            self.qmax = 2**(num_bits - 1) - 1
        else:
            self.qmin = 0
            self.qmax = 2**num_bits - 1

    def forward(self, x):
        # Gradient scale for stable training
        grad_scale = 1.0 / (x.numel() * self.qmax) ** 0.5
        # Shrink the scale's gradient without changing its forward value
        scale = self.scale * grad_scale + self.scale.detach() * (1 - grad_scale)
        # Scale and clamp; apply the straight-through estimator to the
        # rounding op only, so gradients still reach the learnable scale
        x_s = torch.clamp(x / scale, self.qmin, self.qmax)
        x_q = (torch.round(x_s) - x_s).detach() + x_s
        # Dequantize
        return x_q * scale

    def init_scale(self, x):
        """Initialize the scale from tensor statistics."""
        if self.symmetric:
            self.scale.data = x.abs().mean() * 2 / (self.qmax ** 0.5)
        else:
            self.scale.data = (x.max() - x.min()) / self.qmax
```
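
In use, an LSQ node sits in front of the weights (and optionally the activations) of each quantized layer, and its step size is trained jointly with the weights. A hypothetical wrapper for a linear layer:

```python
class LSQLinear(nn.Module):
    """Sketch: linear layer with learned step sizes for weights and activations."""

    def __init__(self, linear, num_bits=8):
        super().__init__()
        self.linear = linear
        self.weight_quant = LSQ(num_bits, symmetric=True)
        self.act_quant = LSQ(num_bits, symmetric=True)  # symmetric keeps the sketch simple
        self.weight_quant.init_scale(linear.weight)

    def forward(self, x):
        w = self.weight_quant(self.linear.weight)
        return self.act_quant(F.linear(x, w, self.linear.bias))
```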

GPTQ: Quantization for Large Language Models

```python
class GPTQ:
    """
    Accurate post-training quantization for GPT-style models.
    Quantizes weights one layer at a time using Hessian information.
    """

    def __init__(self, layer, num_bits=4, group_size=128):
        self.layer = layer
        self.num_bits = num_bits
        self.group_size = group_size  # reserved for grouped quantization
        # Hessian approximation accumulated from calibration inputs
        self.H = None
        self.nsamples = 0

    def add_batch(self, inp):
        """Accumulate the Hessian approximation from an input batch."""
        if len(inp.shape) == 3:
            inp = inp.reshape(-1, inp.shape[-1])
        inp = inp.t()
        if self.H is None:
            self.H = torch.zeros(
                (inp.shape[0], inp.shape[0]),
                device=inp.device
            )
        self.H += inp @ inp.t()
        self.nsamples += inp.shape[1]

    def quantize(self):
        """Quantize layer weights using the GPTQ algorithm."""
        W = self.layer.weight.data.clone()
        H = self.H / self.nsamples
        # Fix dead (all-zero) columns for numerical stability
        dead = torch.diag(H) == 0
        H[dead, dead] = 1
        # Invert H via Cholesky; keep the upper-triangular factor
        H = torch.linalg.cholesky(H)
        H = torch.cholesky_inverse(H)
        H = torch.linalg.cholesky(H, upper=True)
        Hinv = H
        # Quantize column by column
        Q = torch.zeros_like(W)
        for i in range(W.shape[1]):
            w = W[:, i]
            d = Hinv[i, i]
            # Find the quantized value for this column
            q = self._quantize_weight(w)
            Q[:, i] = q
            # Update remaining columns to compensate for the rounding error
            err = (w - q) / d
            W[:, i:] -= err.unsqueeze(1) @ Hinv[i, i:].unsqueeze(0)
        self.layer.weight.data = Q

    def _quantize_weight(self, w):
        """Quantize a weight vector (symmetric)."""
        max_val = w.abs().max()
        scale = max_val / (2**(self.num_bits - 1) - 1)
        q = torch.round(w / scale)
        q = torch.clamp(q, -(2**(self.num_bits - 1)), 2**(self.num_bits - 1) - 1)
        return q * scale
```
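
Driving the class takes one calibration pass per layer: hook the layer to capture its inputs, feed calibration batches, then quantize. A sketch in which the layer path and `calib_batches` are placeholders:

```python
layer = model.transformer.h[0].mlp.fc_in  # placeholder: any nn.Linear
gptq = GPTQ(layer, num_bits=4)

# Capture the layer's inputs during calibration forward passes
handle = layer.register_forward_pre_hook(
    lambda module, args: gptq.add_batch(args[0].detach())
)
with torch.no_grad():
    for batch in calib_batches:  # placeholder calibration data
        model(batch)
handle.remove()

gptq.quantize()  # weights are replaced in place
```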

Quantization for Inference Engines

ONNX Runtime Quantization

```python
from onnxruntime.quantization import (
    QuantType, quantize_dynamic, quantize_static
)


def quantize_for_onnx_runtime(model_path, output_path, mode='dynamic',
                              calibration_data_reader=None):
    if mode == 'dynamic':
        quantize_dynamic(
            model_path,
            output_path,
            weight_type=QuantType.QInt8
        )
    else:  # static: requires calibration data
        quantize_static(
            model_path,
            output_path,
            calibration_data_reader,
            weight_type=QuantType.QInt8
        )
```
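
The static path needs a `CalibrationDataReader` that yields model inputs. A minimal sketch (the input name `'input'` and the sample source are assumptions):

```python
import numpy as np
from onnxruntime.quantization import CalibrationDataReader

class NumpyCalibrationReader(CalibrationDataReader):
    def __init__(self, samples, input_name='input'):
        # samples: iterable of numpy arrays matching the model's input shape
        self.data = iter(
            [{input_name: s.astype(np.float32)} for s in samples]
        )

    def get_next(self):
        # Return the next input dict, or None when calibration is done
        return next(self.data, None)
```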

TensorRT Quantization

```python
import tensorrt as trt


def build_int8_engine(onnx_path, calibration_cache):
    logger = trt.Logger(trt.Logger.WARNING)
    builder = trt.Builder(logger)
    network = builder.create_network(
        1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH)
    )
    parser = trt.OnnxParser(network, logger)
    with open(onnx_path, 'rb') as f:
        parser.parse(f.read())
    config = builder.create_builder_config()
    config.set_flag(trt.BuilderFlag.INT8)
    # EntropyCalibrator is user-defined (sketched below)
    config.int8_calibrator = EntropyCalibrator(calibration_cache)
    engine = builder.build_engine(network, config)
    return engine
```
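
TensorRT's Python API expects the calibrator to subclass `trt.IInt8EntropyCalibrator2`, serve batches from device memory, and optionally persist a cache. A sketch using PyCUDA, where `batches` is an assumed iterable of numpy arrays shaped like the network input:

```python
import os

import numpy as np
import pycuda.autoinit  # noqa: F401 (creates a CUDA context)
import pycuda.driver as cuda
import tensorrt as trt

class EntropyCalibrator(trt.IInt8EntropyCalibrator2):
    def __init__(self, cache_file, batches=()):
        super().__init__()
        self.cache_file = cache_file
        self.batches = iter(batches)
        self.device_input = None

    def get_batch_size(self):
        return 1

    def get_batch(self, names):
        batch = next(self.batches, None)
        if batch is None:
            return None  # tells TensorRT calibration data is exhausted
        if self.device_input is None:
            self.device_input = cuda.mem_alloc(batch.nbytes)
        cuda.memcpy_htod(self.device_input, np.ascontiguousarray(batch))
        return [int(self.device_input)]

    def read_calibration_cache(self):
        if os.path.exists(self.cache_file):
            with open(self.cache_file, 'rb') as f:
                return f.read()

    def write_calibration_cache(self, cache):
        with open(self.cache_file, 'wb') as f:
            f.write(cache)
```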

PyTorch Native Quantization

```python
import torch.quantization


def pytorch_static_quantization(model, train_loader):
    # Fuse conv + bn + relu so they quantize as a single op
    model.eval()
    model_fused = torch.quantization.fuse_modules(
        model, [['conv', 'bn', 'relu']]
    )
    # Prepare for quantization (inserts observers)
    model_fused.qconfig = torch.quantization.get_default_qconfig('fbgemm')
    model_prepared = torch.quantization.prepare(model_fused)
    # Calibrate with representative data
    with torch.no_grad():
        for images, _ in train_loader:
            model_prepared(images)
    # Convert to a quantized model
    model_quantized = torch.quantization.convert(model_prepared)
    return model_quantized
```
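
Eager-mode static quantization also expects the model to mark where tensors enter and leave the int8 domain, and the names passed to `fuse_modules` must match the model's attribute names. A minimal model compatible with the function above (a sketch):

```python
class QuantReadyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.quant = torch.quantization.QuantStub()      # fp32 -> int8 boundary
        self.conv = nn.Conv2d(3, 16, 3, padding=1)
        self.bn = nn.BatchNorm2d(16)
        self.relu = nn.ReLU()
        self.dequant = torch.quantization.DeQuantStub()  # int8 -> fp32 boundary

    def forward(self, x):
        x = self.quant(x)
        x = self.relu(self.bn(self.conv(x)))
        return self.dequant(x)
```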

Practical Considerations

Quantization-Friendly Architectures

```python
class QuantizationFriendlyBlock(nn.Module):
    """Design patterns that quantize well."""

    def __init__(self, in_channels, out_channels):
        super().__init__()
        # Avoid:
        # - Large activation ranges (use BatchNorm)
        # - Concatenating features with very different ranges
        # - Residual connections without range matching
        self.conv = nn.Conv2d(in_channels, out_channels, 3, padding=1)
        self.bn = nn.BatchNorm2d(out_channels)  # keeps ranges in check
        self.relu = nn.ReLU()  # bounded-below activation
        # For the residual path: ensure similar ranges
        if in_channels != out_channels:
            self.residual = nn.Sequential(
                nn.Conv2d(in_channels, out_channels, 1),
                nn.BatchNorm2d(out_channels)
            )
        else:
            self.residual = nn.Identity()

    def forward(self, x):
        identity = self.residual(x)
        out = self.relu(self.bn(self.conv(x)))
        return out + identity  # both branches have similar ranges after BN
```

Calibration Strategies

```python
class CalibrationStrategy:
    """Different ways to determine quantization parameters."""

    @staticmethod
    def minmax(tensor):
        """Simple min-max calibration."""
        return tensor.min(), tensor.max()

    @staticmethod
    def percentile(tensor, percentile=99.99):
        """Clip outliers using percentiles."""
        min_val = torch.quantile(tensor, (100 - percentile) / 100)
        max_val = torch.quantile(tensor, percentile / 100)
        return min_val, max_val

    @staticmethod
    def entropy(tensor, num_bins=2048, num_bits=8):
        """Minimize KL divergence (TensorRT style). Schematic only: the
        histogram-requantization step is elided."""
        hist, bin_edges = torch.histogram(tensor.float(), bins=num_bins)
        best_threshold = None
        best_kl = float('inf')
        for i in range(128, num_bins):
            # Candidate clipping threshold
            threshold = bin_edges[i]
            # Requantize the clipped histogram down to 2**num_bits bins
            q_hist = torch.zeros(2**num_bits)
            # ... compute quantized distribution
            # Compare the quantized distribution against the original
            kl = F.kl_div(q_hist.log(), hist[:i].float())
            if kl < best_kl:
                best_kl = kl
                best_threshold = threshold
        return -best_threshold, best_threshold
```

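The practical difference shows up on heavy-tailed activations: min-max lets a single outlier stretch the range, while percentile clipping keeps resolution where the mass is. A quick comparison using the class above:

```python
# Mostly small values plus one large outlier
x = torch.cat([torch.randn(10000), torch.tensor([40.0])])

lo1, hi1 = CalibrationStrategy.minmax(x)
lo2, hi2 = CalibrationStrategy.percentile(x, 99.9)

print(f"minmax range:     [{lo1:.2f}, {hi1:.2f}]")  # stretched by the outlier
print(f"percentile range: [{lo2:.2f}, {hi2:.2f}]")  # ignores the outlier
```
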
Conclusion

Quantization is essential for efficient AI deployment. From simple post-training quantization to advanced methods like GPTQ for LLMs, these techniques enable running sophisticated models on diverse hardware.

Key takeaways:

  1. PTQ vs QAT: Post-training is faster; QAT recovers more accuracy
  2. Granularity matters: Per-channel often better than per-tensor
  3. Mixed precision: Different layers can use different bit-widths
  4. Hardware awareness: Consider target hardware’s quantization support
  5. Calibration is crucial: Good calibration data improves results
  6. Combine with other techniques: Quantization + pruning + distillation

With proper application of quantization techniques, models can achieve 2-4x speedup with minimal accuracy loss, making AI deployment practical across a wide range of devices from data centers to mobile phones.
