| """
|
| Unit tests for BitLinear and MultiTernaryLinear layers.
|
|
|
| These tests validate the nn.Module implementations and their compatibility with standard PyTorch workflows. The test cases are:
|
|
|
| TestBitLinear (8 tests)
|
| 1. test_initialization - Verifies layer initializes with correct shapes
|
| 2. test_no_bias_initialization - Tests initialization without bias parameter
|
| 3. test_forward_shape - Validates output shape correctness
|
| 4. test_compatibility_with_nn_linear - Tests interface compatibility with nn.Linear
|
| 5. test_from_linear_conversion - Verifies conversion from nn.Linear to BitLinear
|
| 6. test_parameter_count - Validates parameter count calculation
|
| 7. test_weight_values_are_ternary - Ensures weights are in {-1, 0, +1}
|
| 8. test_gradient_flow - Tests gradient flow for QAT support
|
|
|
| TestMultiTernaryLinear (5 tests)
|
| 1. test_initialization - Verifies k-component initialization
|
| 2. test_forward_shape - Tests forward pass output shape
|
| 3. test_k_components - Validates k-component tensor shapes
|
| 4. test_from_linear_conversion - Tests conversion with k parameter
|
| 5. test_better_approximation_with_more_k - Validates error decreases with larger k
|
|
|
| TestConversionUtilities (3 tests)
|
| 1. test_convert_simple_model - Tests conversion of Sequential models
|
| 2. test_convert_nested_model - Tests conversion of nested module hierarchies
|
| 3. test_inplace_conversion - Tests in-place vs. copy conversion modes
|
|
|
| TestLayerIntegration (3 tests)
|
| 1. test_in_transformer_block - Tests BitLinear in Transformer FFN block
|
| 2. test_training_step - Validates full training loop compatibility
|
| 3. test_save_and_load - Tests model serialization and deserialization
|
|
|
| TestPerformanceComparison (2 tests - skipped)
|
| 1. test_memory_usage - Performance benchmark (run manually)
|
| 2. test_inference_speed - Performance benchmark (run manually)
|
| """
|
|
|
| import pytest
|
| import torch
|
| import torch.nn as nn
|
|
|
| from bitlinear import BitLinear, MultiTernaryLinear, convert_linear_to_bitlinear
|
|
|
|
|
class TestBitLinear:
    """Unit tests covering BitLinear construction, shapes, and gradients."""

    def test_initialization(self):
        """A freshly built layer exposes the expected attributes and shapes."""
        bl = BitLinear(512, 1024)
        assert bl.in_features == 512
        assert bl.out_features == 1024
        assert bl.bias is not None
        # Ternary weight matrix is (out_features, in_features), like nn.Linear.
        assert bl.W_ternary.shape == (1024, 512)
        assert bl.gamma.shape == (1024,)

    def test_no_bias_initialization(self):
        """Constructing with bias=False must leave the bias attribute unset."""
        assert BitLinear(512, 1024, bias=False).bias is None

    def test_forward_shape(self):
        """Forward pass keeps leading dims and maps only the feature dim."""
        bl = BitLinear(512, 1024)
        batch = torch.randn(32, 128, 512)
        assert bl(batch).shape == (32, 128, 1024)

    def test_compatibility_with_nn_linear(self):
        """BitLinear matches the nn.Linear call interface and output shape."""
        dense = nn.Linear(512, 512)
        ternary = BitLinear(512, 512)

        inputs = torch.randn(32, 512)
        dense_out = dense(inputs)
        ternary_out = ternary(inputs)

        assert dense_out.shape == ternary_out.shape

    def test_from_linear_conversion(self):
        """from_linear carries over feature dims and yields a callable layer."""
        dense = nn.Linear(512, 1024)
        converted = BitLinear.from_linear(dense)

        assert converted.in_features == 512
        assert converted.out_features == 1024

        probe = torch.randn(16, 512)
        assert converted(probe).shape == (16, 1024)

    def test_parameter_count(self):
        """Total parameters = ternary weights + per-row gamma + bias."""
        bl = BitLinear(512, 512, bias=True)

        # 512*512 ternary weights, 512 gamma scales, 512 bias terms.
        expected = 512 * 512 + 512 + 512
        actual = sum(p.numel() for p in bl.parameters())
        assert actual == expected

    def test_weight_values_are_ternary(self):
        """Stored weights may only take values from {-1, 0, +1}."""
        bl = BitLinear(512, 512)
        observed = set(torch.unique(bl.W_ternary).tolist())
        assert observed <= {-1.0, 0.0, 1.0}

    def test_gradient_flow(self):
        """Backprop must reach the input, the weights, and the scales (QAT)."""
        bl = BitLinear(256, 128)
        inputs = torch.randn(8, 256, requires_grad=True)
        bl(inputs).sum().backward()

        assert inputs.grad is not None
        # QAT requires gradients on both the ternary weights and gamma.
        assert bl.W_ternary.grad is not None
        assert bl.gamma.grad is not None
|
|
|
|
|
class TestMultiTernaryLinear:
    """Tests for MultiTernaryLinear layer."""

    def test_initialization(self):
        """Test layer initialization with k components."""
        layer = MultiTernaryLinear(512, 1024, k=4)
        assert layer.in_features == 512
        assert layer.out_features == 1024
        assert layer.k == 4
        # Component axis first: (k, out_features, in_features).
        assert layer.W_ternary.shape == (4, 1024, 512)
        assert layer.gammas.shape == (4, 1024)

    def test_forward_shape(self):
        """Test forward pass shape."""
        layer = MultiTernaryLinear(512, 1024, k=4)
        x = torch.randn(32, 128, 512)
        output = layer(x)
        assert output.shape == (32, 128, 1024)

    def test_k_components(self):
        """Test that layer uses k ternary components."""
        layer = MultiTernaryLinear(512, 512, k=3)
        assert layer.W_ternary.shape == (3, 512, 512)
        assert layer.gammas.shape == (3, 512)

    def test_from_linear_conversion(self):
        """Test converting nn.Linear to MultiTernaryLinear."""
        linear = nn.Linear(512, 1024)
        multi_ternary = MultiTernaryLinear.from_linear(linear, k=4)
        assert multi_ternary.k == 4
        assert multi_ternary.in_features == 512
        assert multi_ternary.out_features == 1024

    def test_better_approximation_with_more_k(self):
        """Test that larger k provides better approximation of dense layer.

        The RNG is seeded so the strict monotonic-decrease assertion is
        reproducible instead of depending on whatever RNG state earlier
        tests happened to leave behind (the original was flaky by design).
        """
        torch.manual_seed(0)
        linear = nn.Linear(512, 512)
        x = torch.randn(16, 512)
        out_dense = linear(x)

        errors = []
        for k in [1, 2, 4]:
            multi_ternary = MultiTernaryLinear.from_linear(linear, k=k)
            out_ternary = multi_ternary(x)
            # Compare plain floats, not 0-dim tensors.
            errors.append(torch.norm(out_dense - out_ternary).item())

        # Each extra ternary component must shrink the approximation error.
        assert errors[0] > errors[1] > errors[2]
|
|
|
|
|
class TestConversionUtilities:
    """Exercises convert_linear_to_bitlinear on flat and nested models."""

    def test_convert_simple_model(self):
        """Every nn.Linear in a Sequential becomes BitLinear; others survive."""
        original = nn.Sequential(
            nn.Linear(512, 1024),
            nn.ReLU(),
            nn.Linear(1024, 512),
        )

        converted = convert_linear_to_bitlinear(original, inplace=False)

        assert isinstance(converted[0], BitLinear)
        assert isinstance(converted[2], BitLinear)
        # Non-linear modules must be left untouched.
        assert isinstance(converted[1], nn.ReLU)

    def test_convert_nested_model(self):
        """Conversion recurses into submodules, not just top-level children."""

        class NestedModel(nn.Module):
            def __init__(self):
                super().__init__()
                self.layer1 = nn.Linear(256, 512)
                self.submodule = nn.Sequential(
                    nn.Linear(512, 512),
                    nn.ReLU(),
                )
                self.layer2 = nn.Linear(512, 128)

        converted = convert_linear_to_bitlinear(NestedModel(), inplace=False)

        assert isinstance(converted.layer1, BitLinear)
        assert isinstance(converted.submodule[0], BitLinear)
        assert isinstance(converted.layer2, BitLinear)

    def test_inplace_conversion(self):
        """inplace=False leaves the original alone; inplace=True mutates it."""
        original = nn.Sequential(nn.Linear(256, 256))

        # Copy mode: a distinct object whose source is unchanged.
        copied = convert_linear_to_bitlinear(original, inplace=False)
        assert id(original) != id(copied)
        assert isinstance(original[0], nn.Linear)
        assert isinstance(copied[0], BitLinear)

        # In-place mode: the same object comes back, already converted.
        mutable = nn.Sequential(nn.Linear(256, 256))
        returned = convert_linear_to_bitlinear(mutable, inplace=True)
        assert id(mutable) == id(returned)
        assert isinstance(mutable[0], BitLinear)
|
|
|
|
|
class TestLayerIntegration:
    """Integration tests for layers in realistic scenarios."""

    def test_in_transformer_block(self):
        """Test BitLinear inside a Transformer feed-forward (FFN) block."""
        # NOTE: the original docstring said "attention block", but the module
        # below is a standard position-wise FFN (fc1 -> ReLU -> fc2 -> dropout).

        class TransformerFFN(nn.Module):
            def __init__(self, d_model=256, d_ff=1024):
                super().__init__()
                self.fc1 = BitLinear(d_model, d_ff)
                self.relu = nn.ReLU()
                self.fc2 = BitLinear(d_ff, d_model)
                self.dropout = nn.Dropout(0.1)

            def forward(self, x):
                return self.dropout(self.fc2(self.relu(self.fc1(x))))

        model = TransformerFFN()

        batch_size, seq_len, d_model = 8, 32, 256
        x = torch.randn(batch_size, seq_len, d_model)
        output = model(x)

        # The FFN must preserve the (batch, seq, d_model) shape.
        assert output.shape == (batch_size, seq_len, d_model)

        # Both projections must still hold strictly ternary weights.
        assert set(model.fc1.W_ternary.unique().tolist()).issubset({-1.0, 0.0, 1.0})
        assert set(model.fc2.W_ternary.unique().tolist()).issubset({-1.0, 0.0, 1.0})

    def test_training_step(self):
        """Test that layers work in a training loop."""
        model = nn.Sequential(
            BitLinear(128, 256),
            nn.ReLU(),
            BitLinear(256, 10),
        )

        optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

        x = torch.randn(16, 128)
        output = model(x)

        target = torch.randint(0, 10, (16,))
        loss = nn.functional.cross_entropy(output, target)

        optimizer.zero_grad()
        loss.backward()

        # Gradients must reach the quantized parameters before stepping.
        assert model[0].W_ternary.grad is not None
        assert model[0].gamma.grad is not None

        optimizer.step()

        assert torch.isfinite(loss)

    def test_save_and_load(self):
        """Test saving and loading models with BitLinear layers."""
        import tempfile
        import os

        model = nn.Sequential(
            BitLinear(128, 256),
            nn.ReLU(),
            BitLinear(256, 64),
        )

        # Obtain a path first and save only AFTER the handle is closed:
        # writing to the path through a second handle while the
        # NamedTemporaryFile is still open fails on Windows (exclusive lock).
        with tempfile.NamedTemporaryFile(delete=False, suffix='.pt') as f:
            temp_path = f.name
        torch.save(model.state_dict(), temp_path)

        try:
            model_loaded = nn.Sequential(
                BitLinear(128, 256),
                nn.ReLU(),
                BitLinear(256, 64),
            )
            # weights_only=True: the checkpoint holds only tensors, so use
            # the safe loading path (the default in newer PyTorch releases).
            model_loaded.load_state_dict(torch.load(temp_path, weights_only=True))

            # Parameters must round-trip exactly through serialization.
            assert torch.allclose(model[0].W_ternary, model_loaded[0].W_ternary)
            assert torch.allclose(model[0].gamma, model_loaded[0].gamma)
            assert torch.allclose(model[2].W_ternary, model_loaded[2].W_ternary)
            assert torch.allclose(model[2].gamma, model_loaded[2].gamma)

            # Identical parameters must yield identical outputs.
            x = torch.randn(8, 128)
            with torch.no_grad():
                out1 = model(x)
                out2 = model_loaded(x)
            assert torch.allclose(out1, out2)
        finally:
            os.unlink(temp_path)
|
|
|
|
|
|
|
class TestPerformanceComparison:
    """Benchmarks pitting BitLinear against a plain nn.Linear (manual runs)."""

    @pytest.mark.skip("Performance test - run manually")
    def test_memory_usage(self):
        """Compare memory usage of BitLinear vs. nn.Linear."""
        # Intentionally empty: run by hand when profiling memory.

    @pytest.mark.skip("Performance test - run manually")
    def test_inference_speed(self):
        """Compare inference speed (when CUDA kernels are implemented)."""
        # Intentionally empty: run by hand once CUDA kernels exist.
|
|
|