Quantization for IoT Edge inference

I recently started working on quantization for an IoT edge project, and used linear affine quantization to get model weights down to roughly 3.5 bits each.
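
The arithmetic behind the "3.5-bit" label: an 11-value codebook carries log2(11) ≈ 3.46 bits of information per weight, and nine base-11 digits fit into a single 32-bit integer because 11^9 < 2^32. A quick standalone check of that packing budget (separate from the module below):

import math

assert 11 ** 9 < 2 ** 32      # 2,357,947,691 < 4,294,967,296 -> 9 digits fit in one uint32
print(math.log2(11))          # ~3.46 bits of information per weight
print(9 * math.log2(11))      # ~31.1 bits per pack, stored in 32 bits (4 bytes)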

from __future__ import annotations
from pathlib import Path
import torch
import torch.nn as nn
from .bignet import BIGNET_DIM, LayerNorm

# --- Constants for Quantization Logic ---
# These constants define the core parameters of our 3.5-bit quantization scheme
CODEBOOK_SIZE = 11          # Number of discrete values in our quantization codebook (-1 to 1)
PACK_GROUP_SIZE = 9         # Number of indices we pack into a single 32-bit integer for compression


class Base11Quantizer(nn.Module):
    """
    Handles the quantization of weights to an 11-value codebook and the
    dequantization back to floating-point representation.
    
    This quantizer achieves ~3.5 bits per weight through:
    1. Mapping weights to 11 discrete values (log2(11) ≈ 3.46 bits)
    2. Group-wise min/max scaling for better precision
    3. Efficient packing of 9 indices into 32 bits (9 × log2(11) ≈ 31.1 bits, which fits in a 32-bit integer)
    """
    
    def __init__(self, group_size: int = 128):
        """
        Initialize the quantizer with a specific group size.
        
        Args:
            group_size: Number of weights quantized together with shared min/max values.
                       Smaller groups = better precision, larger groups = less metadata overhead.
        """
        super().__init__()
        self.group_size = group_size
        # Create codebook with 11 evenly spaced values from -1 to 1
        # This becomes our "vocabulary" of allowed weight values
        self.register_buffer("codebook", torch.linspace(-1, 1, CODEBOOK_SIZE))

    def quantize(self, weight: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Quantizes a full-precision weight tensor to 3.5-bit representation.
        
        Process:
        1. Group weights for shared scaling
        2. Find min/max per group  
        3. Normalize to [-1, 1] range
        4. Map to nearest codebook values
        5. Pack indices efficiently
        
        Args:
            weight: Full-precision weight tensor [out_features, in_features]
            
        Returns:
            tuple of:
            - packed: Compressed indices as bytes [out_features, num_groups, packed_bytes_per_group]
            - min_val: Minimum values per group [out_features, num_groups] 
            - scale: Scaling factors per group [out_features, num_groups]
        """
        
        # 1. Extract tensor dimensions
        out_features, in_features = weight.shape
        # Example: A linear layer might have shape [512, 1024] meaning 512 output neurons, 1024 inputs
        
        # 2. Calculate number of complete groups we can form
        num_groups = in_features // self.group_size
        # Example: 1024 inputs ÷ 128 group_size = 8 groups
        # Each group will have its own min/max scaling parameters for better precision
        
        # 3. Safety check - ensure we have at least one complete group
        if num_groups == 0:
            raise ValueError("Input tensor dimension is smaller than group_size, cannot quantize.")
        # This prevents errors when trying to quantize very small tensors
        
        # 4. Reshape tensor into groups, discarding remainder features
        reshaped = weight[:, :num_groups * self.group_size].view(out_features, num_groups, self.group_size)
        # Takes only features that fit into complete groups: [512, 1024] → [512, 8, 128]
        # Note: If in_features isn't divisible by group_size, remainder features are discarded
        
        # 5. Find minimum value within each group
        min_val = reshaped.amin(dim=-1, keepdim=True).to(torch.float16)
        # Computes min along last dimension (within each group)
        # Shape: [out_features, num_groups, 1] → [512, 8, 1]
        # Using float16 to reduce metadata memory overhead
        
        # 6. Find maximum value within each group  
        max_val = reshaped.amax(dim=-1, keepdim=True).to(torch.float16)
        # Computes max along last dimension (within each group)
        # Shape: [out_features, num_groups, 1] → [512, 8, 1]
        
        # 7. Calculate dynamic range (scale) for each group
        scale = (max_val - min_val).clamp(min=1e-5).to(torch.float16)
        # Range = max - min, tells us how to scale back during dequantization
        # clamp(min=1e-5) prevents division by zero for groups with constant weights
        
        # 8. Normalize weights to [0, 1] range using group-specific scaling
        normed = (reshaped - min_val) / scale
        # Standard min-max normalization: (value - min) / (max - min)
        # Now all weights are between 0 and 1
        
        # 9. Center weights to [-1, 1] range to match codebook
        centered = normed * 2 - 1
        # Linear mapping from [0, 1] to [-1, 1]: y = 2x - 1
        # This matches our codebook range
        
        # 10. Find closest codebook entry for each weight
        diffs = (centered.unsqueeze(-1) - self.codebook.view(1, 1, 1, -1)).abs()
        # Compute absolute difference between each weight and all 11 codebook values
        # centered: [512, 8, 128, 1], codebook: [1, 1, 1, 11] → diffs: [512, 8, 128, 11]
        
        # 11. Get index of closest codebook value
        indices = diffs.argmin(dim=-1).to(torch.uint8)
        # For each weight, find which codebook entry (0-10) is closest
        # Shape: [512, 8, 128], values in range [0, 10]
        # uint8 is sufficient since we only need to store values 0-10
        
        # 12. Pack indices into compressed byte format
        packed = self._pack(indices)
        # Use the base-11 packing scheme (9 indices per 4 bytes) to compress the indices
        
        # 13. Return compressed data and metadata needed for dequantization
        return packed, min_val.squeeze(-1), scale.squeeze(-1)
        # Remove singleton dimension: [512, 8, 1] → [512, 8]
        # This gives us everything needed to reconstruct the original weights

    def dequantize(self, packed: torch.Tensor, min_val: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
        """
        Reconstructs approximate floating-point weights from quantized representation.
        
        This is the reverse of quantize():
        1. Unpack compressed indices
        2. Look up codebook values
        3. Reverse the [-1,1] to [0,1] mapping  
        4. Apply group-specific scaling and offset
        
        Args:
            packed: Compressed indices [out_features, num_groups, packed_bytes_per_group]
            min_val: Minimum values per group [out_features, num_groups]
            scale: Scaling factors per group [out_features, num_groups]
            
        Returns:
            Reconstructed weight tensor [out_features, in_features]
        """
        # 1. Unpack compressed indices back to integer form
        indices = self._unpack(packed)
        # Convert from compressed bytes back to indices in range [0, 10]
        
        # 2. Look up actual codebook values for each index
        centered = self.codebook[indices]
        # Map indices to their corresponding values in [-1, 1] range
        
        # 3. Convert from [-1, 1] back to [0, 1] range
        normed = (centered + 1) / 2
        # Reverse the centering: y = (x + 1) / 2
        
        # 4. Apply inverse of group-wise normalization
        reshaped = normed * scale.unsqueeze(-1) + min_val.unsqueeze(-1)
        # Reverse min-max scaling: reconstructed = normalized * scale + min
        # unsqueeze(-1) adds dimension for broadcasting: [512, 8] → [512, 8, 1]
        
        # 5. Flatten back to original 2D weight matrix format
        return reshaped.view(packed.size(0), -1)
        # Convert from grouped format [512, 8, 128] back to [512, 1024]

    def _pack(self, indices: torch.Tensor) -> torch.Tensor:
        """
        Packs 9 base-11 indices into a compact 4-byte (32-bit) representation.
        
        This is the core compression technique:
        - 9 indices, each needing log2(11) ≈ 3.46 bits
        - Total: 9 × 3.46 ≈ 31.1 bits, which fits in 32 bits (4 bytes)
        - We treat the 9 indices as digits in a base-11 number system
        
        Args:
            indices: Integer indices [out_features, num_groups, group_size]
            
        Returns:
            Packed byte representation [out_features, num_groups, packed_bytes_per_group]
        """
        B, G, L = indices.shape  # Batch (out_features), Groups, Length (group_size)
        
        # 1. Pad indices to make length divisible by PACK_GROUP_SIZE (9)
        pad_amount = (-L) % PACK_GROUP_SIZE  # Calculate how many zeros to add
        if pad_amount > 0:
            # Add padding zeros to make the tensor divisible by 9
            padding = torch.zeros(B, G, pad_amount, dtype=indices.dtype, device=indices.device)
            indices = torch.cat([indices, padding], dim=-1)
        
        # 2. Reshape into groups of 9 indices
        indices_grouped = indices.view(B, G, -1, PACK_GROUP_SIZE)
        # Example: [512, 8, 128] → [512, 8, 15, 9] (128 isn't divisible by 9, so it's padded up to 135 = 15 × 9)
        
        # 3. Create powers of 11 for base-11 number system
        powers = torch.tensor([CODEBOOK_SIZE ** i for i in range(PACK_GROUP_SIZE)], 
                             dtype=torch.int64, device=indices.device)
        # [11^0, 11^1, 11^2, ..., 11^8] = [1, 11, 121, 1331, ...]
        # These are the place values in base-11
        
        # 4. Convert each group of 9 indices to a single base-11 number
        packed_int32 = torch.sum(indices_grouped.long() * powers.view(1, 1, 1, -1), dim=-1).to(torch.uint32)
        # Multiply each index by its place value and sum: index[0]*1 + index[1]*11 + index[2]*121 + ...
        # This gives us numbers that fit in 32 bits (since 11^9 < 2^32)
        
        # 5. Convert 32-bit integers to byte representation for storage
        packed_bytes = packed_int32.contiguous().view(torch.uint8)
        # Each uint32 becomes 4 uint8 bytes for more efficient storage
        
        return packed_bytes
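
    # Added illustrative note: for a single pack of 9 indices, say
    # [3, 0, 10, 1, 0, 0, 7, 2, 5], _pack computes
    #     packed = 3*11**0 + 0*11**1 + 10*11**2 + 1*11**3 + ... + 5*11**8
    # and digit i can be recovered as (packed // 11**i) % 11, which is exactly
    # what the modulo/integer-division loop in _unpack below does.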

    def _unpack(self, packed_bytes: torch.Tensor) -> torch.Tensor:
        """
        Unpacks byte tensor back to integer indices, reversing the compression.
        
        This reverses _pack() by treating each 32-bit integer as a base-11 number
        and extracting the individual digits (indices).
        
        Args:
            packed_bytes: Compressed byte representation
            
        Returns:
            Integer indices [out_features, num_groups, group_size]
        """
        # 1. Convert bytes back to 32-bit integers
        packed_int32 = packed_bytes.contiguous().view(torch.uint32)
        # Reinterpret every 4 bytes as a single 32-bit integer
        
        # 2. Convert to 64-bit for safe division operations
        temp_packed = packed_int32.to(torch.int64)
        # Use int64 to avoid overflow during repeated division
        
        # 3. Prepare output tensor for unpacked indices
        B, G, num_packed_ints = packed_int32.shape
        num_indices = num_packed_ints * PACK_GROUP_SIZE  # Total indices we'll extract
        indices = torch.zeros(B, G, num_indices, dtype=torch.long, device=packed_int32.device)
        
        # 4. Extract digits from base-11 numbers using repeated division
        for i in range(PACK_GROUP_SIZE):
            # Extract every 9th index (corresponding to current digit position)
            indices[:, :, i::PACK_GROUP_SIZE] = temp_packed % CODEBOOK_SIZE
            # Get remainder when divided by 11 (this gives us the current digit)
            temp_packed //= CODEBOOK_SIZE
            # Integer division by 11 (shift to next digit position)
        
        # 5. Return only the indices we actually need (trim padding)
        return indices[:, :, :self.group_size]
        # Remove any padding added during packing
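
# --- Added example (illustrative sketch, not part of the original post) ---
# Round-trip sanity check for Base11Quantizer. The shapes below are arbitrary
# choices for demonstration; assumes a PyTorch build with torch.uint32 support.
def _demo_quantizer_roundtrip() -> None:
    q = Base11Quantizer(group_size=128)
    w = torch.randn(512, 1024)
    packed, min_val, scale = q.quantize(w)        # packed: [512, 8, 60] uint8
    w_hat = q.dequantize(packed, min_val, scale)  # reconstruction: [512, 1024]
    print("mean abs error:", (w - w_hat).abs().mean().item())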


class Linear11Bit(nn.Module):
    """
    A quantized linear layer that serves as a drop-in replacement for nn.Linear.
    
    Key features:
    - Stores weights in 3.5-bit quantized format (roughly 1/8 the memory of float32 weights)
    - Automatically quantizes weights when loading pretrained models
    - Maintains bias in full precision for better accuracy
    - Supports arbitrary input/output dimensions with group-wise quantization
    """
    
    def __init__(self, in_features: int, out_features: int, bias: bool = True, group_size: int = 256):
        """
        Initialize a quantized linear layer.
        
        Args:
            in_features: Size of input features (same as nn.Linear)
            out_features: Size of output features (same as nn.Linear)  
            bias: Whether to include bias term (same as nn.Linear)
            group_size: Number of weights quantized together (larger = less metadata overhead)
        """
        super().__init__()
        # Store layer dimensions for compatibility with nn.Linear
        self.in_features = in_features      # Number of input neurons
        self.out_features = out_features    # Number of output neurons  
        self.group_size = group_size        # Quantization granularity
        
        # Warn user if dimensions don't divide evenly into groups
        if in_features % self.group_size != 0:
            import warnings
            warnings.warn(f"'in_features' ({in_features}) is not perfectly divisible by 'group_size' ({group_size}). "
                          "The remainder of features will be ignored during quantization.")
            # Example: 1000 features with group_size=256 → only 768 features used (3 complete groups)

        # Calculate quantization structure
        self.num_groups = in_features // self.group_size  # Number of complete groups
        self.quantizer = Base11Quantizer(group_size=self.group_size)  # Our quantization engine

        # Calculate memory layout for packed weights
        num_packs_per_group = (self.group_size + PACK_GROUP_SIZE - 1) // PACK_GROUP_SIZE
        # Ceiling division: how many 9-index packs fit in each group
        # Example: 256 weights ÷ 9 indices per pack = 28.4 → 29 packs needed
        self.packed_bytes_per_group = num_packs_per_group * 4  # 4 bytes per 32-bit pack
        # Example: 29 packs × 4 bytes = 116 bytes per group

        # Register non-trainable buffers for quantized data (these move with the model to GPU/CPU)
        self.register_buffer("quant_packed", 
                           torch.empty(out_features, self.num_groups, self.packed_bytes_per_group, dtype=torch.uint8))
        # Stores compressed weight indices as bytes
        
        self.register_buffer("min_val", 
                           torch.zeros(out_features, self.num_groups, dtype=torch.float16))
        # Stores minimum value for each group (needed for dequantization)
        
        self.register_buffer("scale", 
                           torch.ones(out_features, self.num_groups, dtype=torch.float16))
        # Stores scaling factor for each group (needed for dequantization)

        # Handle bias term (kept in full precision for better accuracy)
        if bias:
            self.bias = nn.Parameter(torch.zeros(out_features, dtype=torch.float32))
            # Trainable parameter in full precision
        else:
            self.register_parameter("bias", None)
            # Explicitly register None to match nn.Linear behavior

        # Register hook for automatic quantization during model loading
        self._register_load_state_dict_pre_hook(self._preload_quantized)
        # This allows seamless conversion from full-precision to quantized models

    def _preload_quantized(self, state_dict: dict, prefix: str, *args, **kwargs) -> None:
        """
        Hook that automatically quantizes weights when loading a state_dict.
        
        This enables converting pretrained full-precision models to quantized format:
        1. Look for 'weight' parameter in the state_dict
        2. Quantize it using our quantizer
        3. Store the compressed result in our buffers
        4. Remove the original weight from state_dict
        
        Args:
            state_dict: Dictionary containing model parameters
            prefix: Module name prefix (for nested modules)
        """
        # Construct the key name for this layer's weight parameter
        weight_key = prefix + "weight"
        
        # Check if this layer has a weight to quantize
        if weight_key in state_dict:
            # Extract the full-precision weight
            w = state_dict.pop(weight_key)  # pop() removes it from state_dict
            
            # Quantize the weight using our 3.5-bit quantizer
            packed, min_val, scale = self.quantizer.quantize(w)
            
            # Store the quantized data in our buffers
            self.quant_packed.copy_(packed)     # Compressed indices
            self.min_val.copy_(min_val)         # Group minimums
            self.scale.copy_(scale)             # Group scales
            
            # Note: The original weight is now removed from state_dict and replaced
            # with our compressed representation

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Forward pass: dequantize weights and perform linear transformation.
        
        Process:
        1. Dequantize compressed weights to approximate floating-point
        2. Perform standard linear operation: output = input @ weight^T + bias
        
        Args:
            x: Input tensor [..., in_features]
            
        Returns:
            Output tensor [..., out_features]
        """
        # Dequantize weights for this forward pass
        with torch.no_grad():  # No gradients needed for quantized weights
            dequantized_weight = self.quantizer.dequantize(self.quant_packed, self.min_val, self.scale)
            # Convert compressed representation back to floating-point weights
        
        # Perform standard linear transformation
        return nn.functional.linear(x, dequantized_weight, self.bias)
        # This is equivalent to: x @ dequantized_weight.T + bias

    def __repr__(self) -> str:
        """String representation matching nn.Linear format."""
        return (f"{self.__class__.__name__}(in_features={self.in_features}, "
                f"out_features={self.out_features}, bias={self.bias is not None}, "
                f"group_size={self.group_size})")


class BigNet11Bit(nn.Module):
    """
    The main quantized model architecture using 3.5-bit weights.
    
    Architecture:
    - 6 residual blocks, each with 3 quantized linear layers
    - Layer normalization between blocks for training stability
    - All linear layers use 3.5-bit quantized weights
    - Weight memory is roughly 1/8 of the float32 full-precision model
    """
    
    class Block(nn.Module):
        """
        A residual block using quantized linear layers.
        
        Structure: x ─→ Linear → ReLU → Linear → ReLU → Linear → (+) → output
                   └───────────── residual connection ────────────┘
        """
        def __init__(self, dim: int):
            """
            Initialize a residual block.
            
            Args:
                dim: Hidden dimension (input/output size)
            """
            super().__init__()
            # Three quantized linear layers with ReLU activations
            self.model = nn.Sequential(
                Linear11Bit(dim, dim),  # 3.5-bit quantized weights
                nn.ReLU(),              # Non-linearity
                Linear11Bit(dim, dim),  # 3.5-bit quantized weights  
                nn.ReLU(),              # Non-linearity
                Linear11Bit(dim, dim),  # 3.5-bit quantized weights
            )

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            """
            Forward pass with residual connection.
            
            Args:
                x: Input tensor [batch_size, dim]
                
            Returns:
                Output tensor [batch_size, dim] with residual connection applied
            """
            # Apply the sequential layers and add residual connection
            return self.model(x) + x
            # Residual connection helps with gradient flow in deep networks

    def __init__(self):
        """Initialize the full quantized model with 6 blocks and layer normalization."""
        super().__init__()
        
        # Build the model architecture: Block → LayerNorm → Block → ... → Block
        self.model = nn.Sequential(
            # Block 1
            self.Block(BIGNET_DIM),
            LayerNorm(BIGNET_DIM),      # Normalize between blocks
            
            # Block 2  
            self.Block(BIGNET_DIM),
            LayerNorm(BIGNET_DIM),
            
            # Block 3
            self.Block(BIGNET_DIM),
            LayerNorm(BIGNET_DIM),
            
            # Block 4
            self.Block(BIGNET_DIM),
            LayerNorm(BIGNET_DIM),
            
            # Block 5
            self.Block(BIGNET_DIM),
            LayerNorm(BIGNET_DIM),
            
            # Block 6 (final block, no LayerNorm after)
            self.Block(BIGNET_DIM),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Forward pass through the entire quantized network.
        
        Args:
            x: Input tensor [batch_size, BIGNET_DIM]
            
        Returns:
            Output tensor [batch_size, BIGNET_DIM]
        """
        return self.model(x)


def load(path: Path | str | None) -> BigNet11Bit:
    """
    Initialize a quantized BigNet model and optionally load pretrained weights.
    
    This function handles the conversion from full-precision to quantized models:
    1. Create quantized model architecture
    2. Load full-precision state_dict (if provided)
    3. Automatic quantization happens via the _preload_quantized hooks
    
    Args:
        path: Path to saved model state_dict (.pt or .pth file)
              If None, returns model with random initialized weights
              
    Returns:
        BigNet11Bit model with quantized weights (if path provided) or random weights
        
    Example:
        # Load pretrained model and auto-convert to quantized
        model = load("pretrained_model.pt")
        
        # Create model with random quantized weights  
        model = load(None)
    """
    # Create the quantized model architecture
    model = BigNet11Bit()
    
    # Load and convert pretrained weights if path is provided
    if path is not None:
        # Load the state_dict from file
        state_dict = torch.load(path, map_location="cpu")  # Load to CPU first
        
        # Load state_dict - this triggers _preload_quantized hooks automatically
        model.load_state_dict(state_dict, strict=False)
        # strict=False allows for some parameter mismatches during quantization conversion
        
    return model
    # Model is now ready with either quantized pretrained weights or random quantized weights        
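
# --- Added example (illustrative sketch, not part of the original post) ---
# Hypothetical end-to-end usage; "bignet.pt" is a placeholder checkpoint path.
def _demo_load_and_run() -> None:
    model = load("bignet.pt")
    x = torch.randn(1, BIGNET_DIM)
    with torch.inference_mode():
        y = model(x)
    total_bytes = sum(t.numel() * t.element_size() for t in model.state_dict().values())
    print("output:", tuple(y.shape), "| quantized state_dict:", total_bytes / 1e6, "MB")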
