| |
| |
|
|
| import torch |
| import torch.nn as nn |
| import torchvision |
| import torchvision.transforms as transforms |
| from torchvision.models import vgg16, vgg19, googlenet, resnet18 |
| import timm |
| import numpy as np |
| import matplotlib.pyplot as plt |
| from torchattacks import FGSM, PGD, APGD |
| import os |
| import time |
| from datetime import datetime |
| import gradio as gr |
|
|
class LeNet(nn.Module):
    """LeNet-5-style CNN for single-channel 28x28 inputs with 10 output classes.

    ``forward`` can optionally return every intermediate activation so that
    layer-wise metrics (e.g. MVL) can be computed without forward hooks.
    """

    def __init__(self):
        super().__init__()
        # 1x28x28 -> conv(5x5) -> 6x24x24 -> pool(2) -> 6x12x12
        self.conv1 = nn.Conv2d(1, 6, 5)
        # 6x12x12 -> conv(5x5) -> 16x8x8 -> pool(2) -> 16x4x4
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(16 * 4 * 4, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)
        self.relu = nn.ReLU()
        self.pool = nn.MaxPool2d(2, 2)

    def forward(self, x, return_all=False):
        """Run the network.

        Args:
            x: input batch; fc1's size assumes (N, 1, 28, 28).
            return_all: when True, return all intermediate outputs instead of
                only the logits.

        Returns:
            The (N, 10) logits, or, if ``return_all`` is set, the list
            ``[pool1, pool2, fc1, fc2, logits]`` of activations.
        """
        x1 = self.pool(self.relu(self.conv1(x)))
        x2 = self.pool(self.relu(self.conv2(x1)))
        # flatten(1) keeps the batch axis intact: a wrongly sized input now
        # fails in fc1 instead of view(-1, 256) silently re-batching features.
        x3 = self.relu(self.fc1(torch.flatten(x2, 1)))
        x4 = self.relu(self.fc2(x3))
        x5 = self.fc3(x4)
        if return_all:
            return [x1, x2, x3, x4, x5]
        return x5
|
|
def salt_pepper_noise(images, prob=0.01, device='cuda', low=0.0, high=1.0):
    """Corrupt a tensor with salt-and-pepper noise.

    A single uniform draw per element decides its fate: with probability
    ``prob / 2`` it becomes ``high`` (salt), with probability ``prob / 2`` it
    becomes ``low`` (pepper), otherwise it is kept. The two events are
    disjoint, so on average exactly a fraction ``prob`` of elements is
    corrupted. The result is clamped to ``[low, high]``.

    Args:
        images: tensor of any shape; it is not modified in place.
        prob: total per-element corruption probability.
        device: unused; kept for backward compatibility (the noise is drawn
            on ``images``' own device via ``torch.rand_like``).
        low: pepper value and lower clamp bound (default 0.0 for [0, 1] data).
        high: salt value and upper clamp bound (default 1.0 for [0, 1] data).
            Pass the corresponding bounds when ``images`` are normalized.

    Returns:
        A new noisy tensor with the same shape as ``images``.
    """
    draw = torch.rand_like(images)
    salt_mask = draw < prob / 2
    pepper_mask = (draw >= prob / 2) & (draw < prob)
    noisy = images.clone()
    noisy[salt_mask] = high
    noisy[pepper_mask] = low
    return torch.clamp(noisy, low, high)
|
|
def pepper_statistical_noise(images, prob=0.01, device='cuda', low=0.0, high=1.0):
    """Set a random fraction of elements to the "pepper" value ``low``.

    Each element is independently replaced by ``low`` with probability
    ``prob``; the result is clamped to ``[low, high]``.

    Args:
        images: tensor of any shape; it is not modified in place.
        prob: per-element probability of being zeroed (set to ``low``).
        device: unused; kept for backward compatibility (the noise is drawn
            on ``images``' own device via ``torch.rand_like``).
        low: pepper value and lower clamp bound (default 0.0 for [0, 1] data).
        high: upper clamp bound (default 1.0 for [0, 1] data). Pass the
            corresponding bounds when ``images`` are normalized.

    Returns:
        A new noisy tensor with the same shape as ``images``.
    """
    pepper_mask = torch.rand_like(images) < prob
    noisy = images.clone()
    noisy[pepper_mask] = low
    return torch.clamp(noisy, low, high)
|
|
def get_layer_outputs(model, input_tensor):
    """Collect the output of every ``nn.Conv2d`` / ``nn.Linear`` layer for one pass.

    Used as a fallback for models whose ``forward`` has no ``return_all``
    switch. A forward hook is registered temporarily on each matching
    submodule and is always removed afterwards, even if the forward pass
    raises. The forward pass runs under ``torch.no_grad()``.

    Side effect: puts ``model`` into eval mode.

    Args:
        model: any ``nn.Module``.
        input_tensor: batch passed unchanged to ``model``.

    Returns:
        List of layer outputs in the order the hooked layers fired.
    """
    captured = []

    def _record(module, inputs, output):
        captured.append(output)

    handles = [
        layer.register_forward_hook(_record)
        for layer in model.modules()
        if isinstance(layer, (nn.Conv2d, nn.Linear))
    ]
    try:
        model.eval()
        with torch.no_grad():
            model(input_tensor)
    finally:
        # A leaked hook would keep firing on every later forward pass.
        for handle in handles:
            handle.remove()
    return captured
|
|
def compute_mvl(model, clean_images, adv_images, device='cuda'):
    """Compute the layer-wise relative error (MVL) between clean and perturbed inputs.

    For every layer output ``L`` this returns the batch mean of
    ``||L(clean) - L(adv)||_2 / (||L(clean)||_2 + 1e-8)``. The norm is taken
    per sample over all non-batch dimensions, whatever the layer's rank.

    Layer outputs come from ``model(x, return_all=True)`` when the model
    supports that keyword. Otherwise, on ``TypeError``, they come from
    ``get_layer_outputs`` (hooks on Conv2d/Linear layers).

    Args:
        model: the network; put into eval mode as a side effect.
        clean_images: unperturbed input batch.
        adv_images: perturbed input batch of the same shape.
        device: unused; kept for backward compatibility.

    Returns:
        List of Python floats, one per layer output.
    """
    model.eval()
    with torch.no_grad():
        try:
            clean_outputs = model(clean_images, return_all=True)
            adv_outputs = model(adv_images, return_all=True)
        except TypeError:
            # forward() has no return_all switch (e.g. torchvision/timm models).
            clean_outputs = get_layer_outputs(model, clean_images)
            adv_outputs = get_layer_outputs(model, adv_images)

    mvl_list = []
    for clean_out, adv_out in zip(clean_outputs, adv_outputs):
        # Flatten every non-batch axis: identical to the old dim=(1,2,3) / dim=1
        # norms for 4-D / 2-D outputs, and correct for any other rank as well.
        clean_flat = clean_out.reshape(clean_out.shape[0], -1)
        adv_flat = adv_out.reshape(adv_out.shape[0], -1)
        diff = torch.linalg.vector_norm(clean_flat - adv_flat, ord=2, dim=1)
        clean_norm = torch.linalg.vector_norm(clean_flat, ord=2, dim=1)
        mvl = diff / (clean_norm + 1e-8)
        mvl_list.append(mvl.mean().item())
    return mvl_list
|
|
def get_model_stats(model):
    """Summarise a model's size.

    Returns:
        ``(trainable_parameter_count, layer_count)``. Only parameters with
        ``requires_grad`` are counted, and "layers" means ``nn.Conv2d`` and
        ``nn.Linear`` submodules.
    """
    trainable = 0
    for param in model.parameters():
        if param.requires_grad:
            trainable += param.numel()

    counted_kinds = (nn.Conv2d, nn.Linear)
    n_layers = sum(1 for module in model.modules() if isinstance(module, counted_kinds))
    return trainable, n_layers
|
|
def modify_model(model, model_name, num_classes=10):
    """Replace a backbone's classification head with a fresh ``num_classes``-way layer.

    The new ``nn.Linear`` takes its input size from the head it replaces, so
    no architecture-specific widths are hard-coded. The new head is randomly
    initialised. Names not listed below (e.g. ``'LeNet'``) leave the model
    untouched.

    Args:
        model: the backbone to modify in place.
        model_name: one of the UI model names ('VGG16', 'VGG19', 'GoogLeNet',
            'ResNet18', 'WideResNet', 'DenseNet121', 'MobileNetV2',
            'EfficientNet-B0').
        num_classes: output size of the new head (default 10).

    Returns:
        The same ``model`` object.
    """
    def _fresh_head(old_head):
        return nn.Linear(old_head.in_features, num_classes)

    if model_name.startswith('VGG'):
        model.classifier[6] = _fresh_head(model.classifier[6])
    elif model_name in ('GoogLeNet', 'ResNet18', 'WideResNet'):
        model.fc = _fresh_head(model.fc)
    elif model_name in ('DenseNet121', 'EfficientNet-B0'):
        model.classifier = _fresh_head(model.classifier)
    elif model_name == 'MobileNetV2':
        # Some MobileNetV2 variants wrap the head as Sequential(Dropout, Linear).
        if isinstance(model.classifier, nn.Sequential):
            model.classifier[1] = _fresh_head(model.classifier[1])
        else:
            model.classifier = _fresh_head(model.classifier)
    return model
|
|
def get_models_for_dataset(dataset_name):
    """List the architecture names offered for ``dataset_name``.

    A new list is returned on every call; unknown datasets yield ``[]``.
    """
    catalogue = (
        ('MNIST', ('LeNet',)),
        ('CIFAR-10', (
            'VGG16', 'VGG19', 'GoogLeNet', 'ResNet18', 'WideResNet',
            'DenseNet121', 'MobileNetV2', 'EfficientNet-B0',
        )),
    )
    for known_name, model_names in catalogue:
        if dataset_name == known_name:
            return list(model_names)
    return []
|
|
def get_dataset_and_transform(dataset_name, root='./data'):
    """Load the test split of a dataset together with its preprocessing transform.

    Args:
        dataset_name: 'MNIST' (resized to 28x28, one channel) or 'CIFAR-10'
            (resized to 224x224 and normalized with ImageNet statistics).
        root: download/cache directory (default './data').

    Returns:
        ``(dataset, transform)``; the dataset is downloaded if missing.

    Raises:
        ValueError: for any other dataset name. Previously such names were
            silently treated as CIFAR-10.

    NOTE(review): both transforms end in ``Normalize``, so tensors leave the
    loader outside [0, 1], while the noise helpers in this file clamp to
    [0, 1] (by default) and torchattacks' attacks presumably assume [0, 1]
    inputs too. Confirm whether normalization should move into the model.
    """
    if dataset_name == 'MNIST':
        transform = transforms.Compose([
            transforms.Resize((28, 28)),
            transforms.Grayscale(num_output_channels=1),
            transforms.ToTensor(),
            transforms.Normalize((0.1307,), (0.3081,)),
        ])
        dataset = torchvision.datasets.MNIST(root=root, train=False, download=True, transform=transform)
    elif dataset_name == 'CIFAR-10':
        transform = transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
            transforms.Normalize((0.485, 0.456, 0.406),
                                 (0.229, 0.224, 0.225)),
        ])
        dataset = torchvision.datasets.CIFAR10(root=root, train=False, download=True, transform=transform)
    else:
        raise ValueError(f"Unknown dataset {dataset_name}")
    return dataset, transform
|
|
def initialize_model(model_name, device):
    """Build the named architecture with a 10-class head and move it to ``device``.

    'LeNet' is constructed directly. The torchvision models (VGG16, VGG19,
    GoogLeNet, ResNet18) load ``IMAGENET1K_V1`` weights, and the timm models
    are created with ``pretrained=True``. For both, ``modify_model`` then
    swaps the classification head.

    NOTE(review): no weights trained for MNIST/CIFAR-10 are loaded here.
    LeNet starts randomly initialised and the replaced heads are untrained,
    so confirm this is intended for the analysis.

    Raises:
        ValueError: if ``model_name`` is not a known architecture.
    """
    if model_name == 'LeNet':
        return LeNet().to(device)

    if model_name in ('VGG16', 'VGG19', 'GoogLeNet', 'ResNet18'):
        builder = {
            'VGG16': vgg16,
            'VGG19': vgg19,
            'GoogLeNet': googlenet,
            'ResNet18': resnet18,
        }[model_name]
        backbone = builder(weights='IMAGENET1K_V1')
    elif model_name in ('WideResNet', 'DenseNet121', 'MobileNetV2', 'EfficientNet-B0'):
        timm_id = {
            'WideResNet': 'wide_resnet50_2',
            'DenseNet121': 'densenet121',
            'MobileNetV2': 'mobilenetv2_100',
            'EfficientNet-B0': 'efficientnet_b0',
        }[model_name]
        backbone = timm.create_model(timm_id, pretrained=True)
    else:
        raise ValueError(f"Unknown model {model_name}")

    return modify_model(backbone, model_name).to(device)
|
|
def _mvl_mean_std(mvl_rows):
    """Reduce a list of per-batch MVL lists to per-layer (mean, std) arrays."""
    mvl_arr = np.array(mvl_rows)
    return np.mean(mvl_arr, axis=0), np.std(mvl_arr, axis=0)


def _plot_mvl_band(mean_vals, std_vals, color, label=None):
    """Draw one MVL curve with a shaded mean ± std band on the current figure."""
    layers = [f"Layer {j+1}" for j in range(len(mean_vals))]
    plt.plot(layers, mean_vals, marker='o', color=color, label=label)
    plt.fill_between(layers, mean_vals - std_vals, mean_vals + std_vals, color=color, alpha=0.3)


def _save_cm_plot(cm_means, cm_stds, attack_names, model_name, dataset_name, output_dir):
    """Save the per-attack CM bar chart (mean with std error bars) and return its path."""
    plt.figure(figsize=(8, 6))
    x = np.arange(len(attack_names))
    plt.bar(x, [cm_means[a] for a in attack_names], yerr=[cm_stds[a] for a in attack_names], capsize=5)
    plt.xticks(x, attack_names, rotation=45)
    plt.ylabel("CM (Relative Error)")
    plt.title(f"CM for {model_name} ({dataset_name})")
    plt.tight_layout()
    path = os.path.join(output_dir, "cm_plot.png")
    plt.savefig(path)
    plt.close()
    return path


def _save_mvl_plot(atk, mean_vals, std_vals, color, output_dir):
    """Save the layer-wise MVL curve of a single attack and return its path."""
    plt.figure(figsize=(8, 6))
    _plot_mvl_band(mean_vals, std_vals, color)
    plt.title(f"MVL per Layer - {atk}")
    plt.ylabel("MVL (Mean ± Std)")
    plt.xticks(rotation=45)
    plt.grid(True)
    plt.tight_layout()
    path = os.path.join(output_dir, f"mvl_{atk.lower().replace(' ', '_')}.png")
    plt.savefig(path)
    plt.close()
    return path


def _save_integrated_mvl_plot(attack_names, mvl_stats, colors, model_name, output_dir):
    """Overlay the MVL curves of all attacks in one figure and return its path."""
    plt.figure(figsize=(10, 6))
    for i, atk in enumerate(attack_names):
        mean_vals, std_vals = mvl_stats[atk]
        _plot_mvl_band(mean_vals, std_vals, colors[i % len(colors)], label=atk)
    plt.title(f"Integrated MVL - {model_name}")
    plt.ylabel("MVL (Mean ± Std)")
    plt.xticks(rotation=45)
    plt.legend()
    plt.grid(True)
    plt.tight_layout()
    path = os.path.join(output_dir, "integrated_mvl.png")
    plt.savefig(path)
    plt.close()
    return path


def layer_sustainability_analysis(dataset_name, model_name, selected_attacks, num_batches, output_dir_base='outputs'):
    """Run the layer sustainability experiment and render its plots.

    For each of the first ``num_batches`` test batches (batch size 32), every
    selected attack perturbs the images and ``compute_mvl`` measures the
    per-layer relative error. The batch CM is the mean MVL over layers.
    Plots are written to ``<output_dir_base>/<model_name>_<timestamp>/``.

    Args:
        dataset_name: 'MNIST' or 'CIFAR-10'.
        model_name: architecture name understood by ``initialize_model``.
        selected_attacks: iterable of attack names (see ``all_attacks`` below).
        num_batches: number of batches to process (floats are truncated).
        output_dir_base: parent directory for the run's output folder.

    Returns:
        A 10-element list: ``[error message or None, CM plot path,
        five MVL plot paths, integrated MVL plot path, statistics markdown,
        log text]``. The five MVL slots follow the fixed attack order
        FGSM, PGD, APGD, Salt & Pepper, Pepper Statistical; an unselected
        attack gets ``None``.
    """
    start_time = time.time()
    logs = ["BSM:: experiment is being started ..."]
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    # UI sliders may deliver floats such as 2.0; batch counts are integers.
    num_batches = int(num_batches)

    logs.append(f"Loading {dataset_name} dataset...")
    dataset, _ = get_dataset_and_transform(dataset_name)
    testloader = torch.utils.data.DataLoader(dataset, batch_size=32, shuffle=False)
    logs.append(f"{dataset_name} dataset loaded with {len(testloader)} batches.")

    logs.append(f"Initializing model {model_name} on {device}...")
    model = initialize_model(model_name, device)
    logs.append(f"Model {model_name} initialized.")

    param_count, layer_count = get_model_stats(model)
    logs.append(f"Model stats: Parameters = {param_count}, Layers = {layer_count}")

    # NOTE(review): the loader yields Normalize'd tensors, while these attacks
    # and noise functions presumably operate on [0, 1] images; confirm.
    all_attacks = {
        'FGSM': FGSM(model, eps=0.03),
        'PGD': PGD(model, eps=0.03, alpha=0.01, steps=40, random_start=True),
        'APGD': APGD(model, eps=0.03, steps=100, loss='ce'),
        'Salt & Pepper': lambda x, y: salt_pepper_noise(x, prob=0.01, device=device),
        'Pepper Statistical': lambda x, y: pepper_statistical_noise(x, prob=0.01, device=device)
    }

    def _failure(message):
        # Same layout as the success path: error, CM, one slot per attack,
        # integrated MVL, stats, logs.
        return [message, None] + [None] * len(all_attacks) + [None, "", '\n'.join(logs)]

    attacks = {name: attack for name, attack in all_attacks.items() if name in selected_attacks}
    if not attacks:
        logs.append("Error: No valid attacks selected")
        return _failure("No valid attacks selected")
    logs.append(f"Selected attacks: {', '.join(attacks.keys())}")

    timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
    output_dir = os.path.join(output_dir_base, f"{model_name}_{timestamp}")
    os.makedirs(output_dir, exist_ok=True)
    logs.append(f"Output directory created: {output_dir}")

    results = {atk: {'cm': [], 'mvl': []} for atk in attacks}
    batches_done = 0
    for i, (images, labels) in enumerate(testloader):
        if i >= num_batches:
            logs.append(f"Reached batch limit: {num_batches}")
            break
        images, labels = images.to(device), labels.to(device)
        logs.append(f"Processing batch {i+1}/{num_batches}...")

        for atk_name, atk in attacks.items():
            logs.append(f" Running attack: {atk_name} on batch {i+1}")
            adv_images = atk(images, labels)
            mvl_vals = compute_mvl(model, images, adv_images, device)
            results[atk_name]['mvl'].append(mvl_vals)
            batch_cm = np.mean(mvl_vals)
            results[atk_name]['cm'].append(batch_cm)
            logs.append(f" Attack {atk_name}: batch CM={batch_cm:.6f}")
        batches_done += 1

    if batches_done == 0:
        # Averaging empty result lists below would only produce NaNs.
        logs.append("Error: No batches were processed")
        return _failure("No batches were processed")

    logs.append("Finished processing batches, computing statistics...")

    attack_names = list(attacks.keys())
    cm_means = {atk: np.mean(results[atk]['cm']) for atk in attack_names}
    cm_stds = {atk: np.std(results[atk]['cm']) for atk in attack_names}
    mvl_stats = {atk: _mvl_mean_std(results[atk]['mvl']) for atk in attack_names}
    colors = ['skyblue', 'lightgreen', 'coral', 'lightgray', 'purple']

    cm_plot_path = _save_cm_plot(cm_means, cm_stds, attack_names, model_name, dataset_name, output_dir)
    logs.append(f"Saved CM plot: {cm_plot_path}")

    mvl_plot_by_attack = {}
    for i, atk in enumerate(attack_names):
        mean_vals, std_vals = mvl_stats[atk]
        path = _save_mvl_plot(atk, mean_vals, std_vals, colors[i % len(colors)], output_dir)
        mvl_plot_by_attack[atk] = path
        logs.append(f"Saved MVL plot for {atk}: {path}")

    integrated_mvl_plot_path = _save_integrated_mvl_plot(attack_names, mvl_stats, colors, model_name, output_dir)
    logs.append(f"Saved integrated MVL plot: {integrated_mvl_plot_path}")

    processing_time = time.time() - start_time
    logs.append(f"Processing completed in {processing_time:.2f} seconds")

    stats = {
        'Dataset': dataset_name,
        'Model': model_name,
        'Parameters': param_count,
        'Layers': layer_count,
        'Batches': batches_done,
        'Attacks': ', '.join(attack_names),
        'Time (s)': round(processing_time, 2)
    }
    stats_text = "## Model Statistics\n\n| Metric | Value |\n|---|---|\n" + "".join(
        f"| {k} | {v} |\n" for k, v in stats.items()
    )

    # One slot per known attack, in the fixed all_attacks order, so each plot
    # fills the output slot (UI tab) of its own attack even when only a
    # subset was selected.
    mvl_plot_paths = [mvl_plot_by_attack.get(name) for name in all_attacks]

    return [
        None,
        cm_plot_path,
        *mvl_plot_paths,
        integrated_mvl_plot_path,
        stats_text,
        '\n'.join(logs)
    ]
|
|
# Static HTML card describing the LSA paper (title, authors, abstract
# highlights and links). create_interface renders it near the top of the UI
# through gr.Markdown.
paper_info_html = """
<div style="border: 1px solid #ccc; padding: 15px; border-radius: 8px; margin-bottom: 15px;">
<h2>Layer-wise Regularized Adversarial Training Using Layers Sustainability Analysis Framework</h2>
<h3>Authors</h3>
<p>Mohammad Khalooei, Mohammad Mehdi Homayounpour, Maryam Amirmazlaghani</p>

<h3>Abstract</h3>
<ul>
<li>The layer sustainability analysis (LSA) framework is introduced to evaluate the behavior of layer-level representations of DNNs in dealing with network input perturbations using Lipschitz theoretical concepts.</li>
<li>A layer-wise regularized adversarial training (AT-LR) approach significantly improves the generalization and robustness of different deep neural network architectures for significant perturbations while reducing layer-level vulnerabilities.</li>
<li>AT-LR loss landscapes for each LSA MVL proposal can interpret layer importance for different layers, which is an intriguing aspect.</li>
</ul>

<h3>Links</h3>
<ul>
<li><a href="https://arxiv.org/abs/2202.02626" target="_blank">ArXiv Paper</a></li>
<li><a href="https://github.com/khalooei/LSA" target="_blank">GitHub Repository</a></li>
<li><a href="https://www.sciencedirect.com/science/article/abs/pii/S0925231223002928" target="_blank">ScienceDirect Article</a></li>
</ul>
</div>
"""
|
|
def update_models(dataset_name):
    """Refresh the model selector after the dataset dropdown changes.

    Returns a pair of Gradio updates for ``(model dropdown, model textbox)``.

    The dropdown is always shown, listing the models available for the
    dataset with the first one pre-selected. Switching back to MNIST
    therefore restores the state the UI starts in: a visible dropdown holding
    only 'LeNet'. Previously that switch hid every model widget. The
    auxiliary textbox stays hidden.

    An unknown dataset yields an empty dropdown with no value instead of an
    IndexError.
    """
    models = get_models_for_dataset(dataset_name)
    dropdown_update = gr.update(choices=models, value=models[0] if models else None, visible=True)
    return dropdown_update, gr.update(visible=False)
|
|
def create_interface():
    """Build the Gradio Blocks UI for the Layer-wise Sustainability Analysis demo.

    The layout has, in order:

    - the dataset/model selectors, attack checkboxes, a batch-count slider
      and a run button;
    - an error box (hidden unless a run reports an error) and the CM image;
    - one MVL tab per attack, followed by Integrated MVL, Model Statistics
      and Logs tabs.

    Returns:
        The (not yet launched) ``gr.Blocks`` app.
    """
    datasets = ['MNIST', 'CIFAR-10']
    # One MVL tab is created per attack, in this order; the run button's
    # outputs map the returned MVL plot paths onto these tabs positionally.
    attacks = ['FGSM', 'PGD', 'APGD', 'Salt & Pepper', 'Pepper Statistical']

    with gr.Blocks() as interface:
        gr.Markdown("# Layer-wise Sustainability Analysis")
        gr.Markdown(paper_info_html)

        initial_input = "MNIST"
        initial_models = get_models_for_dataset(initial_input)
        dataset_input = gr.Dropdown(datasets, label="Select Dataset", value=initial_input)
        model_input = gr.Dropdown(initial_models, value=initial_models[0], label="Select Model")
        model_text = gr.Textbox(value="LeNet", visible=False, interactive=False, label="Model")

        attack_input = gr.CheckboxGroup(choices=attacks, label="Select Attacks", value=attacks)
        batch_input = gr.Slider(minimum=1, maximum=20, step=1, value=2, label="Number of Batches")
        run_button = gr.Button("Run Analysis")

        # Starts hidden; run_analysis reveals it only when an error is reported.
        error_output = gr.Textbox(label="Error", visible=False)
        cm_output = gr.Image(label="Comparative Measure (CM)")

        with gr.Tabs():
            mvl_outputs = []
            for attack in attacks:
                with gr.Tab(f"MVL: {attack}"):
                    mvl_outputs.append(gr.Image(label=f"MVL for {attack}"))
            with gr.Tab("Integrated MVL"):
                integrated_mvl_output = gr.Image(label="Integrated MVL for All Attacks")
            with gr.Tab("Model Statistics"):
                stats_output = gr.Markdown("## Model Statistics")
            with gr.Tab("Logs"):
                log_output = gr.Textbox(label="Processing Logs", lines=15, interactive=False)

        dataset_input.change(
            fn=update_models,
            inputs=dataset_input,
            outputs=[model_input, model_text]
        )

        def get_model_for_mnist_or_dropdown(dataset_name, model_name):
            # MNIST only supports LeNet, whatever the dropdown currently holds.
            return "LeNet" if dataset_name == 'MNIST' else model_name

        def run_analysis(dataset_name, model_name, selected_attacks, batches):
            real_model = get_model_for_mnist_or_dropdown(dataset_name, model_name)
            outputs = list(layer_sustainability_analysis(dataset_name, real_model, selected_attacks, batches))
            error_message = outputs[0]
            # error_output was created hidden, so toggle its visibility;
            # otherwise the error text would never be shown.
            outputs[0] = gr.update(value=error_message or "", visible=bool(error_message))
            return outputs

        run_button.click(
            fn=run_analysis,
            inputs=[dataset_input, model_input, attack_input, batch_input],
            outputs=[error_output, cm_output] + mvl_outputs + [integrated_mvl_output, stats_output, log_output]
        )

    return interface
|
|
if __name__ == '__main__':
    # Script entry point: build the Gradio Blocks UI, then start serving it.
    interface = create_interface()
    interface.launch()
|
|