| import torch |
| from torch import nn |
| import torch.nn.functional as F |
|
|
def modified_weight_quant(w):
    """Quantize a weight tensor to 1.58 bits (ternary values {-1, 0, 1}).

    Per-tensor quantization; no grouping is required. Entries are first
    clamped to [-1, 1] and then rounded to the nearest integer, so the
    result contains only -1, 0, and 1.

    Args:
        w: weight tensor of shape [d, k].
    Returns:
        Ternary tensor of shape [d, k] with the same dtype as ``w``.
    """
    clipped = torch.clamp(w, min=-1.0, max=1.0)
    return torch.round(clipped)
|
|
def normalize(w):
    """L2-normalize each row of ``w`` (along dim=1).

    Args:
        w: tensor of shape [d, k].
    Returns:
        Tensor of shape [d, k] whose rows have unit L2 norm. Rows whose
        norm is (near) zero are returned essentially unchanged instead of
        producing NaN/inf from a division by zero.
    """
    # Clamp the norm away from zero so all-zero rows do not divide by 0.
    norm = torch.norm(w, dim=1, keepdim=True).clamp_min(1e-12)
    return w / norm
|
|
class QLinear(nn.Linear):
    """Linear layer whose weights are quantized to 1.58 bits (ternary) on
    the forward pass via a straight-through estimator.

    Training-only reference implementation; a fused/optimized kernel is
    needed for efficient inference.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Learned per-output-channel scale applied after the quantized matmul.
        self.scales = nn.Parameter(torch.ones(self.out_features))
        # Ternary quantizer (values in {-1, 0, 1}); defined at module level.
        self.quantizer = modified_weight_quant

    def forward(self, x):
        """
        Args:
            x: input tensor of shape [n, in_features].
        Returns:
            y: output tensor of shape [n, out_features].
        """
        w = self.weight
        x = x.to(w.device)

        # Straight-through estimator: the forward pass uses the quantized
        # weights, while gradients flow to `w` as if no quantization
        # happened (the detached difference has zero gradient).
        w_quant = w + (self.quantizer(w) - w).detach()
        y = F.linear(x, w_quant)

        y = y * self.scales
        if self.bias is not None:
            y = y + self.bias
        return y