| import matplotlib.pyplot as plt |
| import time |
| import torch |
| from torch.utils.data import DataLoader |
| from torchvision import datasets, transforms |
| import numpy as np |
| import tracemalloc |
|
|
| |
| from Andromeda.model import Andromeda |
| from Andromeda.utils.stable_adamw import StableAdamWUnfused |
|
|
# Seed both CPU and (when present) CUDA RNGs so benchmark runs are repeatable.
_use_cuda = torch.cuda.is_available()
torch.manual_seed(0)
if _use_cuda:
    torch.cuda.manual_seed(0)

# All benchmarks below run on the GPU when one is available, else the CPU.
device = torch.device("cuda" if _use_cuda else "cpu")
|
|
|
|
class AndromedaModelTest:
    """Smoke tests for Andromeda: forward shape, gradient health, optimizer step."""

    def __init__(self):
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        # Bug fix: the original assigned the bare class object, so calling
        # self.model(...) would have invoked the constructor rather than a
        # forward pass. NOTE(review): assumes Andromeda's default config
        # yields the 64007-token vocabulary asserted below — confirm.
        self.model = Andromeda().to(device)
        # Bug fix: the optimizer must be given the parameters it will update.
        self.optimizer = StableAdamWUnfused(self.model.parameters())
        self.loss_function = torch.nn.CrossEntropyLoss()
        # Bug fix: unconditional .cuda() crashed on CPU-only machines.
        self.test_input = torch.randint(0, 256, (1, 1024)).to(device)

    def _compute_loss(self, output):
        # CrossEntropyLoss expects (N, C) logits with (N,) class targets, so
        # flatten the (batch, seq, vocab) logits and the (batch, seq) targets.
        return self.loss_function(
            output.view(-1, output.size(-1)), self.test_input.view(-1)
        )

    def test_forward_pass(self):
        """The forward output must be shaped (batch=1, seq=1024, vocab=64007)."""
        output = self.model(self.test_input)
        assert output.shape == (1, 1024, 64007), "Forward pass output shape mismatch"

    def test_backward_pass(self):
        """Gradients after one backward pass must be free of NaNs and Infs."""
        self.optimizer.zero_grad()
        output = self.model(self.test_input)
        loss = self._compute_loss(output)
        loss.backward()
        for name, parameter in self.model.named_parameters():
            if parameter.grad is None:
                # Parameter unused by this forward pass; nothing to check.
                continue
            # Bug fix: .grad is an attribute, not a callable, and .any() must
            # apply to the isnan/isinf mask, not to the gradient tensor.
            assert not torch.isnan(parameter.grad).any(), f"Gradient for {name} contains NaNs"
            assert not torch.isinf(parameter.grad).any(), f"Gradient for {name} contains Infs"

    def test_optimizer_step(self):
        """A single optimizer step must change the model parameters."""
        # Bug fix: self.model_parameters() did not exist; snapshot via
        # self.model.parameters().
        initial_params = [param.clone() for param in self.model.parameters()]
        output = self.model(self.test_input)
        loss = self._compute_loss(output)

        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()
        for initial_param, param in zip(initial_params, self.model.parameters()):
            assert not torch.equal(initial_param, param), "Model Parameters did not change after an optimizer step"
|
|
|
|
|
|
|
|
|
|
class SpeedMetrics:
    """Wall-clock timing of decoder forward, backward, and end-to-end passes.

    NOTE(review): forward/backward timings assume the model exposes a
    ``.decoder`` whose forward returns a sequence with the logits first —
    confirm against the Andromeda implementation.
    """

    def __init__(self, model):
        # Resolve the device locally (consistent with FlopsBenchmark) so the
        # class does not depend on a module-level global.
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model = model.to(self.device)

    def _sync(self):
        # CUDA kernels launch asynchronously; without a synchronize the host
        # clock stops before the GPU work has actually finished.
        if self.device.type == "cuda":
            torch.cuda.synchronize()

    def _tokens(self):
        # Fixed-size random token batch shared by all timing methods.
        return torch.randint(0, 50304, (1, 8192), device=self.device, dtype=torch.long)

    def forward_pass_time(self):
        """Return wall-clock seconds for one decoder forward pass."""
        tokens = self._tokens()
        self._sync()
        start = time.perf_counter()  # monotonic, higher resolution than time.time()
        self.model.decoder.forward(tokens)[0]
        self._sync()
        return time.perf_counter() - start

    def backward_pass_time(self):
        """Return wall-clock seconds for loss computation plus backward pass."""
        model_input = self.model.decoder.forward(self._tokens())[0]
        # NOTE(review): assumes the decoder output shape is compatible with
        # CrossEntropyLoss's (N, C, ...) convention — confirm; a
        # (batch, seq, vocab) output would need a transpose here.
        targets = self._tokens()
        self._sync()
        start = time.perf_counter()
        loss = torch.nn.CrossEntropyLoss()(model_input, targets)
        loss.backward()
        self._sync()
        return time.perf_counter() - start

    def end_to_end_latency(self):
        """Return wall-clock seconds for one full model forward pass."""
        tokens = self._tokens()
        self._sync()
        start = time.perf_counter()
        self.model.forward(tokens)
        self._sync()
        return time.perf_counter() - start
| |
|
|
|
|
class ScalabilityMetrics:
    """Throughput measurement over a dataset using batched inference.

    NOTE(review): batches are passed to ``model.forward`` exactly as the
    DataLoader yields them — confirm the model accepts that batch structure.
    """

    def __init__(self, model, dataset):
        self.model = model
        self.dataset = dataset
        self.dataloader = DataLoader(dataset, batch_size=32)

    def throughput(self):
        """Return processed instances per second over one full epoch."""
        start = time.perf_counter()
        # Bug fix: this is an inference-only measurement; without no_grad()
        # autograd graphs accumulate across batches and distort both the
        # timing and the memory footprint.
        with torch.no_grad():
            for batch in self.dataloader:
                self.model.forward(batch)
        elapsed = time.perf_counter() - start
        return len(self.dataset) / elapsed
|
|
|
|
class ConsistencyMetrics:
    """Measures run-to-run timing and output stability of a model."""

    def __init__(self, model):
        self.model = model

    def consistency_over_time(self):
        """Run the same input through the model 10 times.

        Returns:
            (times, score): per-run wall-clock seconds, and the percentage of
            the 9 repeat runs whose output exactly matches the first run.
        """
        # Bug fix: the original drew a fresh random input on every iteration,
        # so outputs almost never matched and the score was meaningless. A
        # single fixed input is required to measure consistency of the model.
        sample = torch.randint(0, 50304, (1, 8192))
        times = []
        outputs_list = []
        for _ in range(10):
            start = time.time()
            with torch.no_grad():  # inference only; no autograd graph needed
                outputs = self.model.forward(sample)
            times.append(time.time() - start)
            # Bug fix: .cpu() before .numpy(); numpy() raises on CUDA tensors.
            outputs_list.append(outputs.detach().cpu().numpy())

        reference = outputs_list[0]
        matches = sum(np.array_equal(reference, out) for out in outputs_list[1:])
        # Bug fix: normalize by the 9 comparisons actually made; dividing by
        # the 10 runs capped the score at 90% even for a perfectly
        # deterministic model.
        score = matches / (len(outputs_list) - 1) * 100
        return times, score
|
|
|
|
class MemoryMetrics:
    """Host-side memory profiling of a single forward pass.

    NOTE: tracemalloc observes Python-heap allocations only; CUDA device
    memory is not captured here.
    """

    def __init__(self, model):
        self.model = model

    def memory_footprint(self):
        """Return (current_bytes, peak_bytes) traced during one forward pass."""
        tracemalloc.start()
        self.model.forward(torch.randint(0, 50304, (1, 8192)))
        snapshot = tracemalloc.get_traced_memory()
        tracemalloc.stop()
        return snapshot
|
|
|
|
class SequenceMetrics:
    """Times forward passes at several input sequence lengths."""

    def __init__(self, model):
        self.model = model

    def sequence_length_impact(self):
        """Return ([lengths], [seconds]) for one forward pass per length."""
        lengths = [1024, 2048, 4096, 8192]
        timings = []
        for n in lengths:
            begin = time.time()
            self.model.forward(torch.randint(0, 50304, (1, n)))
            timings.append(time.time() - begin)
        return lengths, timings
|
|
|
|
|
|
|
|
class FlopsBenchmark:
    """Measures attention-style TFLOPs/s across a range of sequence lengths.

    The FLOP count models scaled-dot-product attention:
    4 * seq_len^2 * head_dim * num_heads per forward pass.
    """

    def __init__(self, model, bsz=32, d_model=1024, num_heads=8, sequence_lengths=None):
        # Bug fix: the original used a mutable list literal as the default
        # argument (shared across instances); build it inside the body.
        if sequence_lengths is None:
            sequence_lengths = list(range(500, 32001, 500))
        self.bsz = bsz
        self.d_model = d_model
        self.num_heads = num_heads
        self.sequence_lengths = sequence_lengths
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.dtype = torch.float32
        self.model = model.to(self.device)

    def _sync(self):
        # Bug fix: the original called torch.cuda.synchronize() unconditionally,
        # which raises on CPU-only machines even though the device fallback in
        # __init__ explicitly supports them.
        if self.device.type == "cuda":
            torch.cuda.synchronize()

    def benchmark(self):
        """Run one forward pass per sequence length and print timing/TFLOPs."""
        time_taken = []
        tflops_per_s = []

        for seq_len in self.sequence_lengths:
            x = torch.randn(self.bsz, seq_len, self.d_model, device=self.device, dtype=self.dtype)
            self._sync()

            start = time.perf_counter()  # monotonic; better than time.time() for intervals
            self.model(x)
            self._sync()
            elapsed = time.perf_counter() - start

            time_taken.append(elapsed)
            # Attention FLOPs: QK^T and AV each cost 2*n^2*d_head per head.
            total_flops = 4 * seq_len**2 * (self.d_model // self.num_heads) * self.num_heads
            tflops_per_s.append(total_flops / elapsed / 1e12)

        for seq_len, elapsed, tflops in zip(self.sequence_lengths, time_taken, tflops_per_s):
            print(f"Sequence length: {seq_len}, Time elapsed: {elapsed} s, TFLOPs/s: {tflops}")
|
|
|
|
| |
# Synthetic image dataset used only to drive the throughput benchmark.
# NOTE(review): FakeData yields (image, label) batches; whether Andromeda's
# forward accepts that batch structure is not visible here — confirm.
test_dataset = datasets.FakeData(size=1000, transform=transforms.ToTensor())

# Model under benchmark.
# NOTE(review): constructor arguments presumably mirror Andromeda's standard
# configuration; not verifiable from this file.
model = Andromeda(
    num_tokens=50304,
    dim=1024,
    depth=24,
    dim_head=128,
    heads=8,
    alibi_num_heads=4
)
|
|
|
|
| |
| |
# --- Benchmark driver -------------------------------------------------------
# NOTE(review): everything below runs at import time; consider moving it under
# the __main__ guard so importing this module does not launch the benchmarks.

# Speed: single forward, backward, and end-to-end latencies (seconds).
speed_metrics = SpeedMetrics(model)
forward_pass_time = speed_metrics.forward_pass_time()
backward_pass_time = speed_metrics.backward_pass_time()
end_to_end_latency = speed_metrics.end_to_end_latency()

# Throughput over the synthetic dataset (instances/second).
scalability_metrics = ScalabilityMetrics(model, test_dataset)
throughput = scalability_metrics.throughput()

# Output stability across repeated runs.
consistency_metrics = ConsistencyMetrics(model)
consistency_times, consistency_score = consistency_metrics.consistency_over_time()

# Python-heap memory usage of one forward pass (bytes).
memory_metrics = MemoryMetrics(model)
current, peak = memory_metrics.memory_footprint()

# Latency as a function of input sequence length.
sequence_metrics = SequenceMetrics(model)
seq_lengths, seq_impact_times = sequence_metrics.sequence_length_impact()

# Attention-FLOPs throughput sweep (prints its own report).
flops_benchmark = FlopsBenchmark(model)
flops_benchmark.benchmark()

# Plot the collected metrics: speed bars, sequence-length impact, consistency.
fig, axs = plt.subplots(3)

axs[0].bar(["Forward Pass Time", "Backward Pass Time", "End-to-End Latency"], [forward_pass_time, backward_pass_time, end_to_end_latency])
axs[0].set_title('Speed Metrics')
axs[0].set_xlabel('Metrics')
axs[0].set_ylabel('Time (seconds)')

axs[1].bar(seq_lengths, seq_impact_times)
axs[1].set_title('Sequence Length Impact')
axs[1].set_xlabel('Sequence Length')
axs[1].set_ylabel('Time (seconds)')

axs[2].plot(list(range(1, 11)), consistency_times)
axs[2].set_title('Consistency Over Time')
axs[2].set_xlabel('Run Number')
axs[2].set_ylabel('Time (seconds)')

plt.tight_layout()
# NOTE(review): plt.show() blocks until the window is closed; headless runs
# may prefer savefig.
plt.show()

print(f"Throughput: {throughput} instances/second")
print(f"Memory used: {current / 10**6}MB; Peak: {peak / 10**6}MB")
|
|
|
|
|
|
| |
if __name__ == "__main__":
    # Run the Andromeda correctness checks end to end.
    suite = AndromedaModelTest()
    suite.test_forward_pass()
    suite.test_backward_pass()
    suite.test_optimizer_step()