| """
|
| BitLinear layer implementations.
|
|
|
| This module provides nn.Module wrappers around the functional implementations,
|
| providing a drop-in replacement for nn.Linear with ternary weights.
|
| """
|
|
|
| import math
|
| import torch
|
| import torch.nn as nn
|
| import torch.nn.functional as F
|
| from typing import Optional
|
|
|
| from .functional import (
|
| bitlinear_python,
|
| greedy_ternary_decomposition,
|
| multi_ternary_linear_python,
|
| )
|
| from .quantization import weight_to_ternary
|
|
|
|
|
class BitLinear(nn.Module):
    """
    BitLinear layer: drop-in replacement for nn.Linear with ternary weights.

    This layer uses ternary weights ({-1, 0, +1}) instead of full-precision
    weights, achieving ~20x memory compression while maintaining competitive
    performance on Transformer models.

    Interface matches nn.Linear:
        - Same initialization arguments (in_features, out_features, bias)
        - Same forward signature
        - Can replace nn.Linear in existing architectures

    Example:
        >>> # Standard Linear
        >>> linear = nn.Linear(512, 512)
        >>> # BitLinear replacement
        >>> bitlinear = BitLinear(512, 512)
        >>> x = torch.randn(32, 128, 512)
        >>> output = bitlinear(x)  # Same interface

    Notes:
        - Weights are quantized to ternary on initialization or conversion
        - Stores ternary weights + scaling factors (gamma)
        - Forward pass uses efficient ternary matrix multiplication
        - Can be trained with QAT (Quantization-Aware Training)

    Attributes:
        in_features: Input dimension
        out_features: Output dimension
        W_ternary: Ternary weight matrix [out_features, in_features]
        gamma: Per-output scaling factors [out_features]
        bias: Optional bias term [out_features]
    """

    def __init__(
        self,
        in_features: int,
        out_features: int,
        bias: bool = True,
        device: Optional[torch.device] = None,
        dtype: Optional[torch.dtype] = None,
    ):
        """
        Initialize BitLinear layer.

        Args:
            in_features: Size of each input sample
            out_features: Size of each output sample
            bias: If True, add learnable bias (default: True)
            device: Device to place parameters on
            dtype: Data type for parameters
        """
        super().__init__()

        self.in_features = in_features
        self.out_features = out_features

        # Honor the requested placement. Previously device/dtype were accepted
        # but ignored, so from_linear() on a CUDA/half model silently produced
        # CPU/float32 parameters.
        factory_kwargs = {'device': device, 'dtype': dtype}

        self.W_ternary = nn.Parameter(
            torch.zeros(out_features, in_features, **factory_kwargs)
        )
        self.gamma = nn.Parameter(torch.ones(out_features, **factory_kwargs))

        if bias:
            self.bias = nn.Parameter(torch.zeros(out_features, **factory_kwargs))
        else:
            self.register_parameter('bias', None)

        self.reset_parameters()

    def reset_parameters(self) -> None:
        """
        Initialize layer parameters.

        Strategy:
            1. Initialize dense weights using standard scheme (kaiming_uniform_)
            2. Quantize to ternary using weight_to_ternary()
            3. Store ternary weights and scaling factors
        """
        # Create the dense init tensor on the same device/dtype as the
        # parameters so the copy_ below does not cross devices implicitly.
        W_dense = torch.empty(
            self.out_features,
            self.in_features,
            device=self.W_ternary.device,
            dtype=self.W_ternary.dtype,
        )
        nn.init.kaiming_uniform_(W_dense, a=math.sqrt(5))

        W_ternary, gamma = weight_to_ternary(W_dense, per_channel=True)
        self.W_ternary.data.copy_(W_ternary)
        self.gamma.data.copy_(gamma)

        if self.bias is not None:
            # Same bias initialization scheme as nn.Linear.
            fan_in, _ = nn.init._calculate_fan_in_and_fan_out(W_dense)
            bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0
            nn.init.uniform_(self.bias, -bound, bound)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Forward pass through BitLinear layer.

        Args:
            x: Input tensor of shape [..., in_features]

        Returns:
            Output tensor of shape [..., out_features]
        """
        return bitlinear_python(x, self.W_ternary, self.gamma, self.bias)

    @classmethod
    def from_linear(cls, linear: nn.Linear) -> 'BitLinear':
        """
        Convert a standard nn.Linear layer to BitLinear.

        This allows converting pre-trained models to use ternary weights.

        Args:
            linear: Standard nn.Linear layer to convert

        Returns:
            BitLinear layer with quantized weights, on the same device/dtype
            as the source layer

        Example:
            >>> linear = nn.Linear(512, 512)
            >>> # ... train linear ...
            >>> bitlinear = BitLinear.from_linear(linear)
        """
        bitlinear = cls(
            linear.in_features,
            linear.out_features,
            bias=linear.bias is not None,
            device=linear.weight.device,
            dtype=linear.weight.dtype,
        )

        W_ternary, gamma = weight_to_ternary(linear.weight.data, per_channel=True)
        bitlinear.W_ternary.data.copy_(W_ternary)
        bitlinear.gamma.data.copy_(gamma)

        if linear.bias is not None:
            bitlinear.bias.data.copy_(linear.bias.data)

        return bitlinear

    def extra_repr(self) -> str:
        """String representation for print()."""
        return f'in_features={self.in_features}, out_features={self.out_features}, bias={self.bias is not None}'
|
|
|
|
|
class MultiTernaryLinear(nn.Module):
    """
    Multi-component ternary linear layer.

    Represents a linear layer as a sum of k ternary components:
        output = sum_{i=1}^k (x @ W_i^T * gamma_i) + bias

    This provides better approximation of dense weights compared to single
    ternary quantization, at the cost of k times more computation.

    References:
        - JMLR paper on ternary representations: https://jmlr.org/papers/volume26/24-2050/24-2050.pdf
        - Greedy ternary decomposition for neural networks

    Attributes:
        in_features: Input dimension
        out_features: Output dimension
        k: Number of ternary components
        W_ternary: Stacked ternary weights [k, out_features, in_features]
        gammas: Stacked scaling factors [k, out_features]
        bias: Optional bias term [out_features]

    Example:
        >>> # Single ternary component (equivalent to BitLinear)
        >>> layer = MultiTernaryLinear(512, 512, k=1)
        >>> # Multiple components for better approximation
        >>> layer = MultiTernaryLinear(512, 512, k=4)
    """

    def __init__(
        self,
        in_features: int,
        out_features: int,
        k: int = 2,
        bias: bool = True,
        device: Optional[torch.device] = None,
        dtype: Optional[torch.dtype] = None,
    ):
        """
        Initialize MultiTernaryLinear layer.

        Args:
            in_features: Size of each input sample
            out_features: Size of each output sample
            k: Number of ternary components (typically 2-4)
            bias: If True, add learnable bias
            device: Device to place parameters on
            dtype: Data type for parameters
        """
        super().__init__()

        self.in_features = in_features
        self.out_features = out_features
        self.k = k

        # Honor the requested placement. Previously device/dtype were accepted
        # but ignored, so from_linear() on a CUDA/half model silently produced
        # CPU/float32 parameters.
        factory_kwargs = {'device': device, 'dtype': dtype}

        self.W_ternary = nn.Parameter(
            torch.zeros(k, out_features, in_features, **factory_kwargs)
        )
        self.gammas = nn.Parameter(torch.ones(k, out_features, **factory_kwargs))

        if bias:
            self.bias = nn.Parameter(torch.zeros(out_features, **factory_kwargs))
        else:
            self.register_parameter('bias', None)

        self.reset_parameters()

    def reset_parameters(self) -> None:
        """
        Initialize layer parameters using greedy ternary decomposition.
        """
        # Create the dense init tensor on the same device/dtype as the
        # parameters so the copy_ below does not cross devices implicitly.
        W_dense = torch.empty(
            self.out_features,
            self.in_features,
            device=self.W_ternary.device,
            dtype=self.W_ternary.dtype,
        )
        nn.init.kaiming_uniform_(W_dense, a=math.sqrt(5))

        W_ternary_list, gamma_list = greedy_ternary_decomposition(W_dense, self.k)
        self.W_ternary.data.copy_(W_ternary_list)
        self.gammas.data.copy_(gamma_list)

        if self.bias is not None:
            # Same bias initialization scheme as nn.Linear.
            fan_in, _ = nn.init._calculate_fan_in_and_fan_out(W_dense)
            bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0
            nn.init.uniform_(self.bias, -bound, bound)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Forward pass through multi-ternary layer.

        Args:
            x: Input tensor of shape [..., in_features]

        Returns:
            Output tensor of shape [..., out_features]
        """
        return multi_ternary_linear_python(x, self.W_ternary, self.gammas, self.bias)

    @classmethod
    def from_linear(cls, linear: nn.Linear, k: int = 2) -> 'MultiTernaryLinear':
        """
        Convert nn.Linear to MultiTernaryLinear using greedy decomposition.

        Args:
            linear: Standard nn.Linear layer
            k: Number of ternary components

        Returns:
            MultiTernaryLinear layer on the same device/dtype as the source
        """
        multi_ternary = cls(
            linear.in_features,
            linear.out_features,
            k=k,
            bias=linear.bias is not None,
            device=linear.weight.device,
            dtype=linear.weight.dtype,
        )

        W_ternary_list, gamma_list = greedy_ternary_decomposition(linear.weight.data, k)
        multi_ternary.W_ternary.data.copy_(W_ternary_list)
        multi_ternary.gammas.data.copy_(gamma_list)

        if linear.bias is not None:
            multi_ternary.bias.data.copy_(linear.bias.data)

        return multi_ternary

    def extra_repr(self) -> str:
        """String representation."""
        return f'in_features={self.in_features}, out_features={self.out_features}, k={self.k}, bias={self.bias is not None}'
|
|
|
|
|
def convert_linear_to_bitlinear(
    module: nn.Module,
    inplace: bool = True,
) -> nn.Module:
    """
    Recursively convert all nn.Linear layers in a module to BitLinear.

    Walks the child-module tree and swaps every nn.Linear for a BitLinear
    built from its weights — useful for converting pre-trained models.

    Args:
        module: PyTorch module (e.g., a Transformer model)
        inplace: If True, modify module in place; if False, return a copy

    Returns:
        Module with Linear layers replaced by BitLinear

    Example:
        >>> model = transformers.GPT2Model.from_pretrained('gpt2')
        >>> model = convert_linear_to_bitlinear(model)
        >>> # All Linear layers are now BitLinear
    """
    if not inplace:
        # Work on a deep copy so the caller's original model is untouched.
        from copy import deepcopy
        module = deepcopy(module)

    for child_name, child in module.named_children():
        if not isinstance(child, nn.Linear):
            # Not a Linear: recurse into the subtree (always in place — any
            # copying already happened at the top level).
            convert_linear_to_bitlinear(child, inplace=True)
            continue
        # Direct Linear child: replace it with a quantized equivalent.
        setattr(module, child_name, BitLinear.from_linear(child))

    return module
|
|
|