| """
|
| Quantization utilities for ternary weight representation.
|
|
|
| This module implements the core quantization functions for converting
|
| dense weights to ternary ({-1, 0, +1}) representation with appropriate
|
| scaling factors.
|
| """
|
|
|
| import torch
|
| from typing import Tuple, Optional
|
|
|
|
|
def absmax_scale(tensor: torch.Tensor, dim: Optional[int] = None) -> torch.Tensor:
    """
    Compute the absmax scaling factor used for quantization.

    The scale is the largest absolute value of ``tensor``, taken either over
    the whole tensor or reduced along ``dim``:

        scale = max(abs(tensor))            # dim is None -> 0-dim tensor
        scale = max(abs(tensor), dim=dim)   # dim is int  -> ``dim`` removed

    No division by a quantization range is applied here, i.e. this is the
    scale for Q_max = 1 (ternary). Callers targeting a wider integer range
    must divide by their own Q_max.

    The result is clamped to at least 1e-5, so an all-zero tensor (or an
    all-zero slice along ``dim``) yields a small positive scale instead of 0.

    Args:
        tensor: Input tensor to compute the scale for.
        dim: Dimension to reduce over. None computes a single global scale.

    Returns:
        Scaling factor(s): a 0-dim tensor when ``dim`` is None, otherwise a
        tensor whose shape is ``tensor.shape`` with ``dim`` removed.

    Examples:
        >>> W = torch.randn(512, 768)       # [out_features, in_features]
        >>> absmax_scale(W, dim=1).shape    # one scale per output channel (row)
        torch.Size([512])
        >>> absmax_scale(W, dim=0).shape    # one scale per input feature (column)
        torch.Size([768])
    """
    magnitudes = torch.abs(tensor)
    if dim is None:
        scale = torch.max(magnitudes)
    else:
        # Reducing without keepdim drops `dim` directly; this matches
        # keepdim=True followed by squeeze(dim).
        scale = torch.max(magnitudes, dim=dim).values

    # Floor the scale so all-zero inputs do not produce a zero scale.
    return torch.clamp(scale, min=1e-5)
|
|
|
|
|
def ternary_quantize(
    tensor: torch.Tensor,
    scale: Optional[torch.Tensor] = None,
    threshold_ratio: float = 0.5,
) -> torch.Tensor:
    """
    Quantize a tensor to ternary values {-1, 0, +1}.

    With ``threshold = threshold_ratio * scale``:
      - values strictly greater than ``threshold``  -> +1
      - values strictly less than ``-threshold``    -> -1
      - everything else (including exactly +/-threshold and NaN) -> 0

    Args:
        tensor: Input tensor to quantize.
        scale: Optional pre-computed scale. If None, the global absmax scale
            ``absmax_scale(tensor, dim=None)`` is used. A non-scalar scale is
            aligned with the LEADING dims of ``tensor``: trailing singleton
            dims are appended until the ranks match. An example is a per-row
            scale of shape [out] for an [out, in] matrix.
        threshold_ratio: Fraction of ``scale`` used as the zeroing threshold
            (default 0.5). It must be non-negative. Larger values produce more
            zeros (higher sparsity); values such as 0.33 or 0.5 are common.

    Returns:
        Tensor with the same shape and dtype as ``tensor`` whose values are in
        {-1, 0, +1}.

    Raises:
        ValueError: If ``threshold_ratio`` is negative. A negative threshold
            would make the +1 and -1 regions overlap.

    Notes:
        - Inspired by BitNet's weight quantization scheme.
    """
    if threshold_ratio < 0:
        raise ValueError(
            f"threshold_ratio must be non-negative, got {threshold_ratio}"
        )

    if scale is None:
        scale = absmax_scale(tensor, dim=None)

    threshold = threshold_ratio * scale

    # Per-slice scales index the leading dims of `tensor`; append trailing
    # singleton dims so the threshold broadcasts over the remaining dims.
    if scale.dim() > 0:
        while threshold.dim() < tensor.dim():
            threshold = threshold.unsqueeze(-1)

    # Start from all zeros so anything that fails both comparisons
    # (in-band values, NaN) stays 0.
    ternary = torch.zeros_like(tensor)
    ternary[tensor > threshold] = 1
    ternary[tensor < -threshold] = -1

    return ternary
|
|
|
|
|
def weight_to_ternary(
    W: torch.Tensor,
    per_channel: bool = True,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Quantize a dense weight matrix into ternary values plus scaling factors.

    Pipeline:
        1. gamma = absmax of W, per row (``per_channel=True``) or global.
        2. W_ternary = ternary_quantize(W, gamma).
        3. Return both.

    Args:
        W: Dense weight matrix of shape [out_features, in_features].
        per_channel: When true, reduce the absmax over dim 1, giving one scale
            per output row (recommended). Otherwise use a single global scale.

    Returns:
        A ``(W_ternary, gamma)`` tuple:
            W_ternary: same shape as ``W``, values in {-1, 0, +1}.
            gamma: shape [out_features] when ``per_channel``, else a 0-dim tensor.

    Examples:
        >>> W = torch.randn(512, 768)
        >>> W_t, gamma = weight_to_ternary(W, per_channel=True)
        >>> W_reconstructed = W_t * gamma.unsqueeze(1)   # global scale: W_t * gamma
        >>> error = torch.norm(W - W_reconstructed)

    Notes:
        - Because gamma is the absmax, every nonzero ternary entry
          reconstructs to exactly +/-gamma of its row (or the global gamma).
        - Per-channel scaling preserves per-output magnitude better than a
          single global scale.
        - Used during layer initialization/conversion.
    """
    reduce_dim = 1 if per_channel else None
    gamma = absmax_scale(W, dim=reduce_dim)
    return ternary_quantize(W, gamma), gamma
|
|
|
|
|
def quantize_activations_absmax(
    x: torch.Tensor,
    bits: int = 8,
    per_token: bool = True,
) -> torch.Tensor:
    """
    Fake-quantize activations with symmetric absmax scaling.

    BitNet quantizes both weights (ternary) and activations (8-bit). This
    function simulates activation quantization:

        q_max   = 2**(bits - 1) - 1
        scale   = max(|x|)  (over the last dim, or globally), floored at 1e-5
        x_quant = round(clamp(x / scale * q_max, -q_max, q_max))
        return    x_quant * scale / q_max

    Args:
        x: Input activations. With ``per_token`` the reduction is over the
            last dim, so any shape works ([batch, seq_len, features] is typical).
        bits: Number of bits for the symmetric integer grid (default 8). Must
            be >= 2; fewer bits leave no nonzero level (q_max <= 0).
        per_token: If True, use one scale per slice along the last dim;
            if False, use a single global scale.

    Returns:
        Dequantized activations as a float tensor of the same shape as ``x``,
        simulating INT8 storage without actually converting the dtype. An
        empty ``x`` is returned as an empty copy.

    Raises:
        ValueError: If ``bits < 2``.

    Notes:
        - Per-token scaling limits the influence of outliers to their own token.
        - NOTE(review): torch.round has zero gradient, and no straight-through
          estimator is applied here. Gradients reach ``x`` only through
          ``scale``. Confirm whether callers add an STE.
    """
    if bits < 2:
        raise ValueError(
            f"bits must be >= 2 for symmetric quantization, got {bits}"
        )

    # torch.max raises on zero-size reductions; an empty input has nothing
    # to quantize.
    if x.numel() == 0:
        return x.clone()

    q_max = 2 ** (bits - 1) - 1

    if per_token:
        # keepdim so the [..., 1] scale broadcasts back over the last dim.
        scale = torch.max(torch.abs(x), dim=-1, keepdim=True)[0]
    else:
        scale = torch.max(torch.abs(x))
    # Floor the scale so an all-zero input/token does not divide by zero.
    scale = torch.clamp(scale, min=1e-5)

    x_scaled = x / scale * q_max
    x_quant = torch.round(torch.clamp(x_scaled, -q_max, q_max))

    return x_quant * scale / q_max
|
|
|
|
|
def dequantize_scale(
    x_quant: torch.Tensor,
    scale: torch.Tensor,
) -> torch.Tensor:
    """
    Map a quantized tensor back to floating point: ``x_quant * scale``.

    How ``scale`` is applied depends on its rank:
      - A 0-dim ``scale`` multiplies every element.
      - A ``scale`` with fewer dims than ``x_quant`` gets trailing size-1 dims
        appended, so its entries line up with the LEADING dims of ``x_quant``.
        An example is a per-row scale of shape [out] against an [out, in]
        matrix.
      - A ``scale`` with at least as many dims is used as-is under ordinary
        broadcasting.

    Args:
        x_quant: Quantized tensor (e.g. ternary or int8-valued).
        scale: Scaling factor(s).

    Returns:
        The dequantized tensor.
    """
    missing_dims = x_quant.dim() - scale.dim()
    if scale.dim() == 0 or missing_dims <= 0:
        return x_quant * scale

    aligned_scale = scale.reshape(*scale.shape, *([1] * missing_dims))
    return x_quant * aligned_scale
|
|
|