Quantization for IoT Edge Inference
I started working on quantization for a project and used group-wise linear (affine) quantization to get down to roughly 3.5 bits per weight.
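The arithmetic behind the "3.5-bit" figure: each weight is snapped to one of 11 codebook values, which carries log2(11) ≈ 3.46 bits of information, and because 11^9 = 2,357,947,691 < 2^32 = 4,294,967,296, nine such base-11 digits pack exactly into one 32-bit integer, i.e. 32 / 9 ≈ 3.56 stored bits per weight. The module below implements this as a drop-in replacement for nn.Linear.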
from __future__ import annotations
from pathlib import Path
import torch
import torch.nn as nn
from .bignet import BIGNET_DIM, LayerNorm
# --- Constants for Quantization Logic ---
# These constants define the core parameters of our 3.5-bit quantization scheme
CODEBOOK_SIZE = 11 # Number of discrete values in our quantization codebook (-1 to 1)
PACK_GROUP_SIZE = 9 # Number of indices we pack into a single 32-bit integer for compression
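# Quick sanity check of the packing arithmetic: nine base-11 digits must fit losslessly
# into one unsigned 32-bit integer, and they do, since 11**9 = 2_357_947_691 < 2**32.
assert CODEBOOK_SIZE ** PACK_GROUP_SIZE < 2 ** 32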
class Base11Quantizer(nn.Module):
"""
Handles the quantization of weights to an 11-value codebook and the
dequantization back to floating-point representation.
This quantizer achieves ~3.5 bits per weight through:
1. Mapping weights to 11 discrete values (log2(11) ≈ 3.46 bits)
2. Group-wise min/max scaling for better precision
3. Efficient packing of 9 indices into a single 32-bit integer (9 × log2(11) ≈ 31.1 bits)
"""
def __init__(self, group_size: int = 128):
"""
Initialize the quantizer with a specific group size.
Args:
group_size: Number of weights quantized together with shared min/max values.
Smaller groups = better precision, larger groups = less metadata overhead.
"""
super().__init__()
self.group_size = group_size
# Create codebook with 11 evenly spaced values from -1 to 1
# This becomes our "vocabulary" of allowed weight values
self.register_buffer("codebook", torch.linspace(-1, 1, CODEBOOK_SIZE))
def quantize(self, weight: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""
Quantizes a full-precision weight tensor to 3.5-bit representation.
Process:
1. Group weights for shared scaling
2. Find min/max per group
3. Normalize to [-1, 1] range
4. Map to nearest codebook values
5. Pack indices efficiently
Args:
weight: Full-precision weight tensor [out_features, in_features]
Returns:
tuple of:
- packed: Compressed indices as bytes [out_features, num_groups, packed_bytes_per_group]
- min_val: Minimum values per group [out_features, num_groups]
- scale: Scaling factors per group [out_features, num_groups]
"""
# 1. Extract tensor dimensions
out_features, in_features = weight.shape
# Example: A linear layer might have shape [512, 1024] meaning 512 output neurons, 1024 inputs
# 2. Calculate number of complete groups we can form
num_groups = in_features // self.group_size
# Example: 1024 inputs ÷ 128 group_size = 8 groups
# Each group will have its own min/max scaling parameters for better precision
# 3. Safety check - ensure we have at least one complete group
if num_groups == 0:
raise ValueError("Input tensor dimension is smaller than group_size, cannot quantize.")
# This prevents errors when trying to quantize very small tensors
# 4. Reshape tensor into groups, discarding remainder features
reshaped = weight[:, :num_groups * self.group_size].view(out_features, num_groups, self.group_size)
# Takes only features that fit into complete groups: [512, 1024] → [512, 8, 128]
# Note: If in_features isn't divisible by group_size, remainder features are discarded
# 5. Find minimum value within each group
min_val = reshaped.amin(dim=-1, keepdim=True).to(torch.float16)
# Computes min along last dimension (within each group)
# Shape: [out_features, num_groups, 1] → [512, 8, 1]
# Using float16 to reduce metadata memory overhead
# 6. Find maximum value within each group
max_val = reshaped.amax(dim=-1, keepdim=True).to(torch.float16)
# Computes max along last dimension (within each group)
# Shape: [out_features, num_groups, 1] → [512, 8, 1]
# 7. Calculate dynamic range (scale) for each group
scale = (max_val - min_val).clamp(min=1e-5).to(torch.float16)
# Range = max - min, tells us how to scale back during dequantization
# clamp(min=1e-5) prevents division by zero for groups with constant weights
# 8. Normalize weights to [0, 1] range using group-specific scaling
normed = (reshaped - min_val) / scale
# Standard min-max normalization: (value - min) / (max - min)
# Now all weights are between 0 and 1
# 9. Center weights to [-1, 1] range to match codebook
centered = normed * 2 - 1
# Linear mapping from [0, 1] to [-1, 1]: y = 2x - 1
# This matches our codebook range
# 10. Find closest codebook entry for each weight
diffs = (centered.unsqueeze(-1) - self.codebook.view(1, 1, 1, -1)).abs()
# Compute absolute difference between each weight and all 11 codebook values
# centered: [512, 8, 128, 1], codebook: [1, 1, 1, 11] → diffs: [512, 8, 128, 11]
# 11. Get index of closest codebook value
indices = diffs.argmin(dim=-1).to(torch.uint8)
# For each weight, find which codebook entry (0-10) is closest
# Shape: [512, 8, 128], values in range [0, 10]
# uint8 is sufficient since we only need to store values 0-10
# 12. Pack indices into compressed byte format
packed = self._pack(indices)
# Use our efficient 9-in-4 packing scheme to compress the indices
# 13. Return compressed data and metadata needed for dequantization
return packed, min_val.squeeze(-1), scale.squeeze(-1)
# Remove singleton dimension: [512, 8, 1] → [512, 8]
# This gives us everything needed to reconstruct the original weights
def dequantize(self, packed: torch.Tensor, min_val: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
"""
Reconstructs approximate floating-point weights from quantized representation.
This is the reverse of quantize():
1. Unpack compressed indices
2. Look up codebook values
3. Reverse the [-1,1] to [0,1] mapping
4. Apply group-specific scaling and offset
Args:
packed: Compressed indices [out_features, num_groups, packed_bytes_per_group]
min_val: Minimum values per group [out_features, num_groups]
scale: Scaling factors per group [out_features, num_groups]
Returns:
Reconstructed weight tensor [out_features, in_features]
"""
# 1. Unpack compressed indices back to integer form
indices = self._unpack(packed)
# Convert from compressed bytes back to indices in range [0, 10]
# 2. Look up actual codebook values for each index
centered = self.codebook[indices]
# Map indices to their corresponding values in [-1, 1] range
# 3. Convert from [-1, 1] back to [0, 1] range
normed = (centered + 1) / 2
# Reverse the centering: y = (x + 1) / 2
# 4. Apply inverse of group-wise normalization
reshaped = normed * scale.unsqueeze(-1) + min_val.unsqueeze(-1)
# Reverse min-max scaling: reconstructed = normalized * scale + min
# unsqueeze(-1) adds dimension for broadcasting: [512, 8] → [512, 8, 1]
# 5. Flatten back to original 2D weight matrix format
return reshaped.view(packed.size(0), -1)
# Convert from grouped format [512, 8, 128] back to [512, 1024]
def _pack(self, indices: torch.Tensor) -> torch.Tensor:
"""
Packs 9 base-11 indices into a compact 4-byte (32-bit) representation.
This is the core compression technique:
- 9 indices, each needing log2(11) ≈ 3.46 bits
- Total: 9 × 3.46 ≈ 31.1 bits ≈ 32 bits (4 bytes)
- We treat the 9 indices as digits in a base-11 number system
Args:
indices: Integer indices [out_features, num_groups, group_size]
Returns:
Packed byte representation [out_features, num_groups, packed_bytes_per_group]
"""
B, G, L = indices.shape # Batch (out_features), Groups, Length (group_size)
# 1. Pad indices to make length divisible by PACK_GROUP_SIZE (9)
pad_amount = (-L) % PACK_GROUP_SIZE # Calculate how many zeros to add
if pad_amount > 0:
# Add padding zeros to make the tensor divisible by 9
padding = torch.zeros(B, G, pad_amount, dtype=indices.dtype, device=indices.device)
indices = torch.cat([indices, padding], dim=-1)
# 2. Reshape into groups of 9 indices
indices_grouped = indices.view(B, G, -1, PACK_GROUP_SIZE)
# Example: [512, 8, 128] → [512, 8, 15, 9] (128 is padded up to 135 = 15 × 9)
# 3. Create powers of 11 for base-11 number system
powers = torch.tensor([CODEBOOK_SIZE ** i for i in range(PACK_GROUP_SIZE)],
dtype=torch.int64, device=indices.device)
# [11^0, 11^1, 11^2, ..., 11^8] = [1, 11, 121, 1331, ...]
# These are the place values in base-11
# 4. Convert each group of 9 indices to a single base-11 number
packed_int32 = torch.sum(indices_grouped.long() * powers.view(1, 1, 1, -1), dim=-1).to(torch.uint32)
# Multiply each index by its place value and sum: index[0]*1 + index[1]*11 + index[2]*121 + ...
# This gives us numbers that fit in 32 bits (since 11^9 < 2^32)
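# Note: torch.uint32 only exists in fairly recent PyTorch releases (it landed around 2.3);
# on older versions this cast and the uint8 reinterpretation below would need a workaround.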
# 5. Convert 32-bit integers to byte representation for storage
packed_bytes = packed_int32.contiguous().view(torch.uint8)
# Each uint32 is reinterpreted as 4 uint8 bytes so the buffer can be stored and moved as plain bytes
return packed_bytes
def _unpack(self, packed_bytes: torch.Tensor) -> torch.Tensor:
"""
Unpacks byte tensor back to integer indices, reversing the compression.
This reverses _pack() by treating each 32-bit integer as a base-11 number
and extracting the individual digits (indices).
Args:
packed_bytes: Compressed byte representation
Returns:
Integer indices [out_features, num_groups, group_size]
"""
# 1. Convert bytes back to 32-bit integers
packed_int32 = packed_bytes.contiguous().view(torch.uint32)
# Reinterpret every 4 bytes as a single 32-bit integer
# 2. Convert to 64-bit for safe division operations
temp_packed = packed_int32.to(torch.int64)
# Use int64 to avoid overflow during repeated division
# 3. Prepare output tensor for unpacked indices
B, G, num_packed_ints = packed_int32.shape
num_indices = num_packed_ints * PACK_GROUP_SIZE # Total indices we'll extract
indices = torch.zeros(B, G, num_indices, dtype=torch.long, device=packed_int32.device)
# 4. Extract digits from base-11 numbers using repeated division
for i in range(PACK_GROUP_SIZE):
# Digit i of each packed integer fills position i within its block of 9 indices
indices[:, :, i::PACK_GROUP_SIZE] = temp_packed % CODEBOOK_SIZE
# Get remainder when divided by 11 (this gives us the current digit)
temp_packed //= CODEBOOK_SIZE
# Integer division by 11 (shift to next digit position)
# 5. Return only the indices we actually need (trim padding)
return indices[:, :, :self.group_size]
# Remove any padding added during packing
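# Back-of-the-envelope error bound for this scheme (ignoring the extra fp16 rounding of
# min_val/scale): the codebook step is 2 / (CODEBOOK_SIZE - 1) = 0.2 in [-1, 1] space, so
# nearest-neighbour rounding is off by at most 0.1 there, or 0.05 after mapping back to
# [0, 1]. Per weight that means |w - w_hat| <= 0.05 * (max - min) of its group.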
class Linear11Bit(nn.Module):
"""
A quantized linear layer that serves as a drop-in replacement for nn.Linear.
Key features:
- Stores weights in 3.5-bit quantized format (roughly 8x smaller than fp32 weights)
- Automatically quantizes weights when loading pretrained models
- Maintains bias in full precision for better accuracy
- Supports arbitrary input/output dimensions with group-wise quantization
"""
def __init__(self, in_features: int, out_features: int, bias: bool = True, group_size: int = 256):
"""
Initialize a quantized linear layer.
Args:
in_features: Size of input features (same as nn.Linear)
out_features: Size of output features (same as nn.Linear)
bias: Whether to include bias term (same as nn.Linear)
group_size: Number of weights quantized together (larger = less metadata overhead)
"""
super().__init__()
# Store layer dimensions for compatibility with nn.Linear
self.in_features = in_features # Number of input neurons
self.out_features = out_features # Number of output neurons
self.group_size = group_size # Quantization granularity
# Warn user if dimensions don't divide evenly into groups
if in_features % self.group_size != 0:
import warnings
warnings.warn(f"'in_features' ({in_features}) is not perfectly divisible by 'group_size' ({group_size}). "
"The remainder of features will be ignored during quantization.")
# Example: 1000 features with group_size=256 → only 768 features used (3 complete groups)
# Calculate quantization structure
self.num_groups = in_features // self.group_size # Number of complete groups
self.quantizer = Base11Quantizer(group_size=self.group_size) # Our quantization engine
# Calculate memory layout for packed weights
num_packs_per_group = (self.group_size + PACK_GROUP_SIZE - 1) // PACK_GROUP_SIZE
# Ceiling division: how many 9-index packs fit in each group
# Example: 256 weights ÷ 9 indices per pack = 28.4 → 29 packs needed
self.packed_bytes_per_group = num_packs_per_group * 4 # 4 bytes per 32-bit pack
# Example: 29 packs × 4 bytes = 116 bytes per group
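# Rough storage cost for the default group_size=256 (illustrative arithmetic only):
# 116 bytes = 928 bits of packed indices per 256 weights ≈ 3.6 bits/weight, plus one fp16
# min and one fp16 scale per group (32 bits / 256 ≈ 0.1 bits/weight), for ≈ 3.75 bits/weight
# in total, versus 32 bits/weight for fp32.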
# Register non-trainable buffers for quantized data (these move with the model to GPU/CPU)
self.register_buffer("quant_packed",
torch.empty(out_features, self.num_groups, self.packed_bytes_per_group, dtype=torch.uint8))
# Stores compressed weight indices as bytes
self.register_buffer("min_val",
torch.zeros(out_features, self.num_groups, dtype=torch.float16))
# Stores minimum value for each group (needed for dequantization)
self.register_buffer("scale",
torch.ones(out_features, self.num_groups, dtype=torch.float16))
# Stores scaling factor for each group (needed for dequantization)
# Handle bias term (kept in full precision for better accuracy)
if bias:
self.bias = nn.Parameter(torch.zeros(out_features, dtype=torch.float32))
# Trainable parameter in full precision
else:
self.register_parameter("bias", None)
# Explicitly register None to match nn.Linear behavior
# Register hook for automatic quantization during model loading
self._register_load_state_dict_pre_hook(self._preload_quantized)
# This allows seamless conversion from full-precision to quantized models
def _preload_quantized(self, state_dict: dict, prefix: str, *args, **kwargs) -> None:
"""
Hook that automatically quantizes weights when loading a state_dict.
This enables converting pretrained full-precision models to quantized format:
1. Look for 'weight' parameter in the state_dict
2. Quantize it using our quantizer
3. Store the compressed result in our buffers
4. Remove the original weight from state_dict
Args:
state_dict: Dictionary containing model parameters
prefix: Module name prefix (for nested modules)
"""
# Construct the key name for this layer's weight parameter
weight_key = prefix + "weight"
# Check if this layer has a weight to quantize
if weight_key in state_dict:
# Extract the full-precision weight
w = state_dict.pop(weight_key) # pop() removes it from state_dict
# Quantize the weight using our 3.5-bit quantizer
packed, min_val, scale = self.quantizer.quantize(w)
# Store the quantized data in our buffers
self.quant_packed.copy_(packed) # Compressed indices
self.min_val.copy_(min_val) # Group minimums
self.scale.copy_(scale) # Group scales
# Note: The original weight is now removed from state_dict and replaced
# with our compressed representation
def forward(self, x: torch.Tensor) -> torch.Tensor:
"""
Forward pass: dequantize weights and perform linear transformation.
Process:
1. Dequantize compressed weights to approximate floating-point
2. Perform standard linear operation: output = input @ weight^T + bias
Args:
x: Input tensor [..., in_features]
Returns:
Output tensor [..., out_features]
"""
# Dequantize weights for this forward pass
with torch.no_grad(): # No gradients needed for quantized weights
dequantized_weight = self.quantizer.dequantize(self.quant_packed, self.min_val, self.scale)
# Convert compressed representation back to floating-point weights
# Perform standard linear transformation
return nn.functional.linear(x, dequantized_weight, self.bias)
# This is equivalent to: x @ dequantized_weight.T + bias
def __repr__(self) -> str:
"""String representation matching nn.Linear format."""
return (f"{self.__class__.__name__}(in_features={self.in_features}, "
f"out_features={self.out_features}, bias={self.bias is not None}, "
f"group_size={self.group_size})")
class BigNet11Bit(nn.Module):
"""
The main quantized model architecture using 3.5-bit weights.
Architecture:
- 6 residual blocks, each with 3 quantized linear layers
- Layer normalization between blocks for training stability
- All linear layers use 3.5-bit quantized weights
- Achieves roughly 8x weight compression compared to a full-precision (fp32) model
"""
class Block(nn.Module):
"""
A residual block using quantized linear layers.
Structure: Linear → ReLU → Linear → ReLU → Linear → (+) → output
↑_______________________________________|
residual connection
"""
def __init__(self, dim: int):
"""
Initialize a residual block.
Args:
dim: Hidden dimension (input/output size)
"""
super().__init__()
# Three quantized linear layers with ReLU activations
self.model = nn.Sequential(
Linear11Bit(dim, dim), # 3.5-bit quantized weights
nn.ReLU(), # Non-linearity
Linear11Bit(dim, dim), # 3.5-bit quantized weights
nn.ReLU(), # Non-linearity
Linear11Bit(dim, dim), # 3.5-bit quantized weights
)
def forward(self, x: torch.Tensor) -> torch.Tensor:
"""
Forward pass with residual connection.
Args:
x: Input tensor [batch_size, dim]
Returns:
Output tensor [batch_size, dim] with residual connection applied
"""
# Apply the sequential layers and add residual connection
return self.model(x) + x
# Residual connection helps with gradient flow in deep networks
def __init__(self):
"""Initialize the full quantized model with 6 blocks and layer normalization."""
super().__init__()
# Build the model architecture: Block → LayerNorm → Block → ... → Block
self.model = nn.Sequential(
# Block 1
self.Block(BIGNET_DIM),
LayerNorm(BIGNET_DIM), # Normalize between blocks
# Block 2
self.Block(BIGNET_DIM),
LayerNorm(BIGNET_DIM),
# Block 3
self.Block(BIGNET_DIM),
LayerNorm(BIGNET_DIM),
# Block 4
self.Block(BIGNET_DIM),
LayerNorm(BIGNET_DIM),
# Block 5
self.Block(BIGNET_DIM),
LayerNorm(BIGNET_DIM),
# Block 6 (final block, no LayerNorm after)
self.Block(BIGNET_DIM),
)
def forward(self, x: torch.Tensor) -> torch.Tensor:
"""
Forward pass through the entire quantized network.
Args:
x: Input tensor [batch_size, BIGNET_DIM]
Returns:
Output tensor [batch_size, BIGNET_DIM]
"""
return self.model(x)
def load(path: Path | str | None) -> BigNet11Bit:
"""
Initialize a quantized BigNet model and optionally load pretrained weights.
This function handles the conversion from full-precision to quantized models:
1. Create quantized model architecture
2. Load full-precision state_dict (if provided)
3. Automatic quantization happens via the _preload_quantized hooks
Args:
path: Path to saved model state_dict (.pt or .pth file)
If None, returns model with random initialized weights
Returns:
BigNet11Bit model with quantized weights (if path provided) or random weights
Example:
# Load pretrained model and auto-convert to quantized
model = load("pretrained_model.pt")
# Create model with random quantized weights
model = load(None)
"""
# Create the quantized model architecture
model = BigNet11Bit()
# Load and convert pretrained weights if path is provided
if path is not None:
# Load the state_dict from file
state_dict = torch.load(path, map_location="cpu") # Load to CPU first
# Load state_dict - this triggers _preload_quantized hooks automatically
model.load_state_dict(state_dict, strict=False)
# strict=False allows for some parameter mismatches during quantization conversion
return model
# Model is now ready with either quantized pretrained weights or random quantized weights
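# A minimal smoke test (illustrative only; the shapes below are arbitrary, and the module
# must be run with `python -m` because of the relative import above). It round-trips a
# random weight through the quantizer and checks that Linear11Bit can absorb a plain
# nn.Linear state_dict via its pre-load hook.
if __name__ == "__main__":
    torch.manual_seed(0)
    # Quantizer round-trip: reconstruction error should stay within the coarse 3.5-bit budget
    q = Base11Quantizer(group_size=128)
    w = torch.randn(64, 512)
    packed, min_val, scale = q.quantize(w)
    w_hat = q.dequantize(packed, min_val, scale)
    print("max abs reconstruction error:", (w - w_hat).abs().max().item())
    # Drop-in replacement check: load weights from an ordinary nn.Linear
    ref = nn.Linear(512, 64)
    qlin = Linear11Bit(512, 64, group_size=128)
    qlin.load_state_dict(ref.state_dict(), strict=False)
    x = torch.randn(4, 512)
    print("mean abs output difference:", (qlin(x) - ref(x)).abs().mean().item())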