| """
|
| Example: Using BitLinear as a drop-in replacement for nn.Linear in a Transformer.
|
|
|
| This example demonstrates:
|
| 1. Creating a simple Transformer block with standard nn.Linear
|
| 2. Converting it to use BitLinear layers
|
| 3. Running forward passes to verify compatibility
|
| 4. Comparing memory usage and output similarity
|
| """
|
|
|
| import torch
|
| import torch.nn as nn
|
| import torch.nn.functional as F
|
| from typing import Optional
|
|
|
| from bitlinear import BitLinear, MultiTernaryLinear, convert_linear_to_bitlinear
|
|
|
|
|
class TransformerBlock(nn.Module):
    """
    Simplified pre-norm Transformer block for demonstration.

    Contains:
        - Multi-head self-attention with linear projections
        - Feed-forward network with two linear layers
        - Layer normalization and residual connections

    Every projection is a plain ``nn.Linear``.

    Args:
        d_model: Width of the token embeddings; must be divisible by ``nhead``.
        nhead: Number of attention heads.
        dim_feedforward: Hidden width of the feed-forward network.
        dropout: Dropout probability used after attention, inside the
            feed-forward network, and after the feed-forward network.

    Raises:
        ValueError: If ``nhead`` is not positive or does not divide ``d_model``.
    """

    def __init__(
        self,
        d_model: int = 512,
        nhead: int = 8,
        dim_feedforward: int = 2048,
        dropout: float = 0.1,
    ):
        super().__init__()

        # Fail early with a clear message: an uneven split would otherwise
        # only surface later as a shape error inside ``view`` in forward().
        if nhead <= 0 or d_model % nhead != 0:
            raise ValueError(
                f"d_model ({d_model}) must be divisible by a positive "
                f"nhead ({nhead})"
            )

        self.d_model = d_model
        self.nhead = nhead
        self.d_k = d_model // nhead  # per-head width

        # Attention projections.
        self.q_proj = nn.Linear(d_model, d_model)
        self.k_proj = nn.Linear(d_model, d_model)
        self.v_proj = nn.Linear(d_model, d_model)
        self.out_proj = nn.Linear(d_model, d_model)

        # Position-wise feed-forward network.
        self.ffn = nn.Sequential(
            nn.Linear(d_model, dim_feedforward),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(dim_feedforward, d_model),
        )

        # One LayerNorm in front of each sub-layer.
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)

        # Dropout applied to each sub-layer output before the residual add.
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)

    def forward(
        self,
        x: torch.Tensor,
        mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """
        Forward pass through Transformer block.

        Args:
            x: Input tensor [batch_size, seq_len, d_model]
            mask: Optional attention mask, broadcastable to the score tensor
                [batch_size, nhead, seq_len, seq_len]. Positions where
                ``mask == 0`` are excluded from attention.

        Returns:
            Output tensor [batch_size, seq_len, d_model]
        """
        # --- Self-attention sub-layer (normalize, attend, residual) ---
        residual = x
        x = self.norm1(x)

        q = self.q_proj(x)
        k = self.k_proj(x)
        v = self.v_proj(x)

        # Split the model dimension into heads:
        # [batch, seq, d_model] -> [batch, nhead, seq, d_k]
        batch_size, seq_len, _ = x.shape
        q = q.view(batch_size, seq_len, self.nhead, self.d_k).transpose(1, 2)
        k = k.view(batch_size, seq_len, self.nhead, self.d_k).transpose(1, 2)
        v = v.view(batch_size, seq_len, self.nhead, self.d_k).transpose(1, 2)

        # Scaled dot-product attention.
        scores = torch.matmul(q, k.transpose(-2, -1)) / (self.d_k ** 0.5)
        if mask is not None:
            # Use the most negative finite value of the score dtype rather
            # than a hard-coded -1e9, which is out of range for float16.
            fill_value = torch.finfo(scores.dtype).min
            scores = scores.masked_fill(mask == 0, fill_value)
        attn_weights = F.softmax(scores, dim=-1)
        attn_output = torch.matmul(attn_weights, v)

        # Merge heads back: [batch, nhead, seq, d_k] -> [batch, seq, d_model]
        attn_output = attn_output.transpose(1, 2).contiguous().view(
            batch_size, seq_len, self.d_model
        )
        attn_output = self.out_proj(attn_output)
        attn_output = self.dropout1(attn_output)

        x = residual + attn_output

        # --- Feed-forward sub-layer (normalize, transform, residual) ---
        residual = x
        x = self.norm2(x)
        x = self.ffn(x)
        x = self.dropout2(x)

        x = residual + x

        return x
|
|
|
|
|
def count_parameters(model: nn.Module) -> int:
    """Count total trainable parameters in a model.

    Only parameters with ``requires_grad=True`` contribute; frozen
    parameters are skipped.

    Args:
        model: Module whose parameters are counted.

    Returns:
        Total number of scalar elements across all trainable parameters.
    """
    return sum(p.numel() for p in model.parameters() if p.requires_grad)
|
|
|
|
|
def estimate_memory_mb(model: nn.Module, include_buffers: bool = False) -> float:
    """Estimate memory usage of model parameters in MB.

    The estimate is the sum of ``numel() * element_size()`` over the model's
    tensors, so it reflects storage dtype but not allocator overhead,
    gradients, or activations.

    Args:
        model: Module to measure.
        include_buffers: When True, tensors registered as buffers are added
            to the total. Defaults to False, which counts parameters only.

    Returns:
        Size in mebibytes (bytes / 1024**2).
    """
    tensors = list(model.parameters())
    if include_buffers:
        tensors.extend(model.buffers())

    total_bytes = sum(t.numel() * t.element_size() for t in tensors)
    return total_bytes / (1024 ** 2)
|
|
|
|
|
def compare_outputs(
    output1: torch.Tensor,
    output2: torch.Tensor,
) -> dict:
    """
    Compare two output tensors and compute similarity metrics.

    Args:
        output1: Reference tensor; the relative error is normalized by its norm.
        output2: Tensor to compare against the reference. Must have the same
            shape as ``output1``.

    Returns:
        Dictionary with comparison metrics:
            - "mse": mean squared error between the tensors
            - "cosine_similarity": cosine similarity of the flattened tensors
            - "relative_error": ||output1 - output2|| / ||output1||

    Raises:
        ValueError: If the two tensors do not have the same shape.
    """
    # mse_loss would silently broadcast mismatched shapes; reject them up
    # front so the metrics are never computed over misaligned data.
    if output1.shape != output2.shape:
        raise ValueError(
            f"Shape mismatch: {tuple(output1.shape)} vs {tuple(output2.shape)}"
        )

    mse = F.mse_loss(output1, output2).item()
    cosine_sim = F.cosine_similarity(
        output1.flatten(), output2.flatten(), dim=0
    ).item()

    diff_norm = torch.norm(output1 - output2)
    if diff_norm.item() == 0.0:
        # Identical tensors have zero error even when the reference norm is
        # also zero (which would otherwise give 0/0 = NaN).
        relative_error = 0.0
    else:
        relative_error = (diff_norm / torch.norm(output1)).item()

    return {
        "mse": mse,
        "cosine_similarity": cosine_sim,
        "relative_error": relative_error,
    }
|
|
|
|
|
def main():
    """Main example demonstrating BitLinear usage in Transformer.

    Builds a standard ``TransformerBlock``, converts a copy of it with
    ``convert_linear_to_bitlinear``, runs both on the same random input, and
    prints parameter counts, memory estimates, and output similarity metrics.
    """

    def count_layers(model, layer_type):
        """Return how many sub-modules of ``model`` are ``layer_type`` instances."""
        return sum(1 for module in model.modules() if isinstance(module, layer_type))

    print("=" * 80)
    print("BitLinear Transformer Example")
    print("=" * 80)

    # Problem size.
    batch_size = 32
    seq_len = 128
    d_model = 512
    nhead = 8
    dim_feedforward = 2048

    # Random input shared by both models.
    x = torch.randn(batch_size, seq_len, d_model)
    print(f"\nInput shape: {x.shape}")

    print("\n" + "-" * 80)
    print("1. Standard Transformer with nn.Linear")
    print("-" * 80)

    model_standard = TransformerBlock(
        d_model=d_model,
        nhead=nhead,
        dim_feedforward=dim_feedforward,
    )
    # torch.no_grad() does not disable dropout. Without eval() each forward
    # pass would drop random activations, and the comparison below would
    # measure dropout noise instead of quantization error.
    model_standard.eval()

    print(f"Parameters: {count_parameters(model_standard):,}")
    print(f"Memory: {estimate_memory_mb(model_standard):.2f} MB")

    with torch.no_grad():
        output_standard = model_standard(x)
    print(f"Output shape: {output_standard.shape}")

    print("\n" + "-" * 80)
    print("2. Transformer with BitLinear")
    print("-" * 80)

    model_bitlinear = convert_linear_to_bitlinear(model_standard, inplace=False)
    # Set inference mode explicitly rather than relying on the conversion
    # helper to carry the training flag over.
    model_bitlinear.eval()

    print(f"Parameters: {count_parameters(model_bitlinear):,}")
    print(f"Memory: {estimate_memory_mb(model_bitlinear):.2f} MB")

    with torch.no_grad():
        output_bitlinear = model_bitlinear(x)
    print(f"Output shape: {output_bitlinear.shape}")

    print("\n" + "-" * 80)
    print("3. Output Comparison")
    print("-" * 80)

    metrics = compare_outputs(output_standard, output_bitlinear)
    print(f"MSE: {metrics['mse']:.6f}")
    print(f"Cosine similarity: {metrics['cosine_similarity']:.6f}")
    print(f"Relative error: {metrics['relative_error']:.6f}")

    print("\n" + "-" * 80)
    print("4. Memory Savings")
    print("-" * 80)

    mem_standard = estimate_memory_mb(model_standard)
    mem_bitlinear = estimate_memory_mb(model_bitlinear)
    savings = (mem_standard - mem_bitlinear) / mem_standard * 100

    print(f"Standard model: {mem_standard:.2f} MB")
    print(f"BitLinear model: {mem_bitlinear:.2f} MB")
    print(f"Memory savings: {savings:.1f}%")
    print(f"Compression ratio: {mem_standard / mem_bitlinear:.1f}x")

    print("\n" + "-" * 80)
    print("5. Conversion Details")
    print("-" * 80)

    print(f"Original Linear layers: {count_layers(model_standard, nn.Linear)}")
    print(f"Converted BitLinear layers: {count_layers(model_bitlinear, BitLinear)}")

    print("\n" + "=" * 80)
    print("Example complete!")
    print("=" * 80)
    print("\nKey Takeaways:")
    print("- BitLinear is a drop-in replacement for nn.Linear")
    print("- Significant memory savings (~20x for weights)")
    print("- Output similarity is high (cosine sim > 0.99 typically)")
    print("- Slight accuracy trade-off due to ternary quantization")
|
|
|
|
|
# Run the demonstration only when executed as a script, not on import.
if __name__ == "__main__":
    main()
|
|
|