Tjayush committed
Commit d582edd · verified · 1 Parent(s): 9de0eba

Add complete research_paper.py implementation (1713 lines, 65KB)

Files changed (1)
  1. research_paper.py +1713 -0
research_paper.py ADDED
#!/usr/bin/env python3
"""
UFUSC: Unified Federated Unlearning via Sensitivity-Guided Contrastive Forgetting

A complete self-contained implementation for the research paper:
"Sensitivity-Guided Contrastive Forgetting: Unified Label and Feature Unlearning
in Vertical Federated Learning"

This script includes:
- VFL architecture (PassiveModel, ActiveModel, VFLFramework)
- 5 baselines (GradientAscent, Finetune, FisherForgetting, ManifoldMixup, Ferrari)
- UFUSC with 3 variants (Label Only, Feature Only, Joint)
- MIA attack evaluation
- Dataset loaders for MNIST, Fashion-MNIST, CIFAR-10
- Ablation study runner
- Scalability analysis across K=2,3,4,6 passive parties
- Visualization code (bar charts, radar plots, ablation plots, scalability plots)

Usage:
    pip install torch torchvision numpy matplotlib seaborn pandas scikit-learn
    python research_paper.py

Author: UFUSC Research Team
"""

import os
import json
import time
import copy
import random
import warnings
from collections import defaultdict

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset, Subset
import torchvision
import torchvision.transforms as transforms
from sklearn.metrics import accuracy_score, roc_auc_score

warnings.filterwarnings("ignore")

# ============================================================================
# Configuration
# ============================================================================

SEED = 42
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
NUM_PASSIVE_PARTIES = 2   # Default K=2 for VFL
BATCH_SIZE = 256
TRAIN_EPOCHS = 20
UNLEARN_EPOCHS = 10
LR = 0.001
FORGET_RATIO = 0.1        # Fraction of data to forget (specific class)

# UFUSC hyperparameters
ALPHA = 1.0               # Contrastive Forgetting Loss weight
BETA = 0.5                # Feature Sensitivity Loss weight
GAMMA = 0.3               # Anchor Loss weight
OMEGA = 0.1               # Dual variable / certification constraint weight
TAU = 2.0                 # Forgetting threshold for certification
SENSITIVITY_SIGMA = 0.01  # Perturbation std for feature sensitivity
SENSITIVITY_SAMPLES = 5   # MC samples for sensitivity estimation

# Output directories
os.makedirs("results", exist_ok=True)
os.makedirs("figures", exist_ok=True)


def set_seed(seed=SEED):
    """Set all random seeds for reproducibility."""
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False


# ============================================================================
# Dataset Loaders
# ============================================================================

def load_dataset(name="MNIST"):
    """
    Load and preprocess a dataset. Returns flattened feature vectors for VFL.

    In VFL, each passive party holds a vertical partition of the features.
    We flatten images and split feature columns across K parties.

    Args:
        name: One of "MNIST", "Fashion-MNIST", "CIFAR-10"

    Returns:
        (X_train, y_train, X_test, y_test, num_classes, feature_dim)
    """
    data_dir = "./data"

    if name == "MNIST":
        transform = transforms.Compose([transforms.ToTensor(),
                                        transforms.Normalize((0.1307,), (0.3081,))])
        train_ds = torchvision.datasets.MNIST(data_dir, train=True, download=True, transform=transform)
        test_ds = torchvision.datasets.MNIST(data_dir, train=False, download=True, transform=transform)
        num_classes = 10
    elif name == "Fashion-MNIST":
        transform = transforms.Compose([transforms.ToTensor(),
                                        transforms.Normalize((0.2860,), (0.3530,))])
        train_ds = torchvision.datasets.FashionMNIST(data_dir, train=True, download=True, transform=transform)
        test_ds = torchvision.datasets.FashionMNIST(data_dir, train=False, download=True, transform=transform)
        num_classes = 10
    elif name == "CIFAR-10":
        transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2470, 0.2435, 0.2616))
        ])
        train_ds = torchvision.datasets.CIFAR10(data_dir, train=True, download=True, transform=transform)
        test_ds = torchvision.datasets.CIFAR10(data_dir, train=False, download=True, transform=transform)
        num_classes = 10
    else:
        raise ValueError(f"Unknown dataset: {name}")

    # Extract and flatten
    X_train = torch.stack([train_ds[i][0] for i in range(len(train_ds))]).view(len(train_ds), -1)
    y_train = torch.tensor([train_ds[i][1] for i in range(len(train_ds))])
    X_test = torch.stack([test_ds[i][0] for i in range(len(test_ds))]).view(len(test_ds), -1)
    y_test = torch.tensor([test_ds[i][1] for i in range(len(test_ds))])

    feature_dim = X_train.shape[1]
    print(f"  [{name}] Train: {X_train.shape}, Test: {X_test.shape}, "
          f"Classes: {num_classes}, Features: {feature_dim}")

    return X_train, y_train, X_test, y_test, num_classes, feature_dim


def split_features_vfl(X, num_parties=NUM_PASSIVE_PARTIES):
    """
    Split feature columns across K passive parties for VFL.

    Each party gets a disjoint subset of columns (vertical partition).

    Args:
        X: (N, D) tensor of flattened features
        num_parties: number of passive parties K

    Returns:
        List of K tensors, each approximately (N, D/K)
    """
    D = X.shape[1]
    split_sizes = [D // num_parties] * num_parties
    # Distribute remainder
    for i in range(D % num_parties):
        split_sizes[i] += 1
    return torch.split(X, split_sizes, dim=1)
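

# --- Illustrative sketch (editor addition, not part of the original commit):
# a self-contained check of the column split above with hypothetical sizes.
# With D=10 features and K=3 parties the remainder is spread over the first
# parties, giving widths [4, 3, 3]; concatenating the parts recovers the
# original matrix. Wrapped in a function so importing stays side-effect free.
def _demo_split_features_vfl():
    X = torch.arange(20.0).reshape(2, 10)  # toy (N=2, D=10) feature matrix
    parts = split_features_vfl(X, num_parties=3)
    assert [p.shape[1] for p in parts] == [4, 3, 3]
    assert torch.equal(torch.cat(parts, dim=1), X)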


def create_forget_retain_split(y, forget_class=0, forget_ratio=FORGET_RATIO):
    """
    Create forget/retain index split.

    Selects a fraction of samples from the target class as the forget set.
    All other samples form the retain set.

    Args:
        y: label tensor
        forget_class: which class to partially forget
        forget_ratio: fraction of that class to forget

    Returns:
        (forget_indices, retain_indices)
    """
    class_indices = (y == forget_class).nonzero(as_tuple=True)[0]
    num_forget = max(1, int(len(class_indices) * forget_ratio))

    perm = torch.randperm(len(class_indices))
    forget_indices = class_indices[perm[:num_forget]]

    all_indices = torch.arange(len(y))
    mask = torch.ones(len(y), dtype=torch.bool)
    mask[forget_indices] = False
    retain_indices = all_indices[mask]

    return forget_indices, retain_indices
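

# --- Illustrative sketch (editor addition, hypothetical labels): with four
# class-0 samples and forget_ratio=0.5, two of them land in the forget set
# and everything else (including all other classes) is retained.
def _demo_forget_retain_split():
    y = torch.tensor([0, 0, 0, 0, 1, 2])
    f_idx, r_idx = create_forget_retain_split(y, forget_class=0, forget_ratio=0.5)
    assert len(f_idx) == 2 and len(r_idx) == 4
    assert set(f_idx.tolist()).isdisjoint(r_idx.tolist())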


# ============================================================================
# VFL Architecture
# ============================================================================

class PassiveModel(nn.Module):
    """
    Passive party model in VFL.

    Each passive party holds a vertical partition of features and computes
    a local embedding (forward representation) that is sent to the active party.

    Architecture: 2-layer MLP with ReLU and BatchNorm.
    """

    def __init__(self, input_dim, embed_dim=64):
        super().__init__()
        hidden_dim = max(128, input_dim // 2)
        self.net = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.BatchNorm1d(hidden_dim),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(hidden_dim, embed_dim),
            nn.BatchNorm1d(embed_dim),
            nn.ReLU()
        )

    def forward(self, x):
        return self.net(x)


class ActiveModel(nn.Module):
    """
    Active party model in VFL.

    The active party holds the labels and receives concatenated embeddings
    from all passive parties. It performs final classification.

    Architecture: 3-layer MLP with ReLU and Dropout; outputs raw logits
    (softmax is applied downstream where probabilities are needed).
    """

    def __init__(self, total_embed_dim, num_classes=10):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(total_embed_dim, 128),
            nn.BatchNorm1d(128),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(128, 64),
            nn.ReLU(),
            nn.Linear(64, num_classes)
        )

    def forward(self, x):
        return self.net(x)


class VFLFramework:
    """
    Vertical Federated Learning framework.

    Manages K passive parties and 1 active party. Each passive party
    computes embeddings from its feature partition, which are concatenated
    and fed to the active party for classification.

    The active party holds labels and orchestrates training.
    """

    def __init__(self, feature_dims, num_classes=10, embed_dim=64,
                 num_parties=NUM_PASSIVE_PARTIES, lr=LR):
        """
        Args:
            feature_dims: list of input dimensions for each passive party
            num_classes: number of output classes
            embed_dim: embedding dimension per passive party
            num_parties: number of passive parties K
            lr: learning rate
        """
        self.num_parties = num_parties
        self.embed_dim = embed_dim
        self.num_classes = num_classes

        # Create passive models
        self.passive_models = []
        for i in range(num_parties):
            model = PassiveModel(feature_dims[i], embed_dim).to(DEVICE)
            self.passive_models.append(model)

        # Create active model
        total_embed = embed_dim * num_parties
        self.active_model = ActiveModel(total_embed, num_classes).to(DEVICE)

        # Optimizers
        all_params = []
        for pm in self.passive_models:
            all_params += list(pm.parameters())
        all_params += list(self.active_model.parameters())
        self.optimizer = optim.Adam(all_params, lr=lr)
        self.criterion = nn.CrossEntropyLoss()

    def get_embeddings(self, X_splits):
        """Compute embeddings from all passive parties and concatenate."""
        embeddings = []
        for i, pm in enumerate(self.passive_models):
            emb = pm(X_splits[i].to(DEVICE))
            embeddings.append(emb)
        return torch.cat(embeddings, dim=1)

    def forward(self, X_splits):
        """Full forward pass through VFL."""
        combined = self.get_embeddings(X_splits)
        logits = self.active_model(combined)
        return logits, combined

    def train_model(self, X_train_splits, y_train, X_test_splits, y_test,
                    epochs=TRAIN_EPOCHS, verbose=True):
        """
        Train the VFL model end-to-end.

        Args:
            X_train_splits: list of K tensors (one per passive party)
            y_train: training labels
            X_test_splits: list of K test tensors
            y_test: test labels
            epochs: number of training epochs
            verbose: print progress
        """
        dataset = TensorDataset(*X_train_splits, y_train)
        loader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True, drop_last=False)

        self.set_train()

        for epoch in range(epochs):
            total_loss = 0
            correct = 0
            total = 0

            for batch in loader:
                *batch_splits, batch_y = batch
                batch_y = batch_y.to(DEVICE)

                logits, _ = self.forward(batch_splits)
                loss = self.criterion(logits, batch_y)

                self.optimizer.zero_grad()
                loss.backward()
                self.optimizer.step()

                total_loss += loss.item() * batch_y.size(0)
                preds = logits.argmax(dim=1)
                correct += (preds == batch_y).sum().item()
                total += batch_y.size(0)

            if verbose and (epoch + 1) % 5 == 0:
                train_acc = correct / total * 100
                test_acc = self.evaluate(X_test_splits, y_test)
                print(f"    Epoch {epoch+1}/{epochs} — Loss: {total_loss/total:.4f}, "
                      f"Train Acc: {train_acc:.2f}%, Test Acc: {test_acc:.2f}%")

    def evaluate(self, X_splits, y, batch_size=512):
        """Evaluate accuracy on given data."""
        self.set_eval()
        dataset = TensorDataset(*X_splits, y)
        loader = DataLoader(dataset, batch_size=batch_size, shuffle=False)

        correct = 0
        total = 0

        with torch.no_grad():
            for batch in loader:
                *batch_splits, batch_y = batch
                batch_y = batch_y.to(DEVICE)
                logits, _ = self.forward(batch_splits)
                preds = logits.argmax(dim=1)
                correct += (preds == batch_y).sum().item()
                total += batch_y.size(0)

        self.set_train()
        return correct / total * 100

    def predict_proba(self, X_splits, batch_size=512):
        """Get prediction probabilities."""
        self.set_eval()
        dataset = TensorDataset(*X_splits)
        loader = DataLoader(dataset, batch_size=batch_size, shuffle=False)

        all_probs = []
        with torch.no_grad():
            for batch in loader:
                logits, _ = self.forward(list(batch))
                probs = F.softmax(logits, dim=1)
                all_probs.append(probs.cpu())

        self.set_train()
        return torch.cat(all_probs, dim=0)

    def set_train(self):
        for pm in self.passive_models:
            pm.train()
        self.active_model.train()

    def set_eval(self):
        for pm in self.passive_models:
            pm.eval()
        self.active_model.eval()

    def clone(self):
        """Deep copy all models. A fresh optimizer is created for the clone;
        optimizer state (e.g. Adam moments) is not carried over."""
        cloned = VFLFramework.__new__(VFLFramework)
        cloned.num_parties = self.num_parties
        cloned.embed_dim = self.embed_dim
        cloned.num_classes = self.num_classes
        cloned.passive_models = [copy.deepcopy(pm) for pm in self.passive_models]
        cloned.active_model = copy.deepcopy(self.active_model)
        cloned.criterion = nn.CrossEntropyLoss()

        all_params = []
        for pm in cloned.passive_models:
            all_params += list(pm.parameters())
        all_params += list(cloned.active_model.parameters())
        cloned.optimizer = optim.Adam(all_params, lr=LR)

        return cloned
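

# --- Illustrative sketch (editor addition): end-to-end wiring of the VFL
# pieces above on random data, assuming two passive parties with 6 and 4
# features. All sizes are hypothetical; this is a smoke test, not an
# experiment from the paper.
def _demo_vfl_forward():
    set_seed()
    fw = VFLFramework(feature_dims=[6, 4], num_classes=3, embed_dim=8, num_parties=2)
    X_splits = [torch.randn(16, 6), torch.randn(16, 4)]  # one tensor per party
    logits, joint_emb = fw.forward(X_splits)
    assert logits.shape == (16, 3)         # class scores from the active party
    assert joint_emb.shape == (16, 2 * 8)  # concatenated passive embeddings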


# ============================================================================
# Evaluation Metrics
# ============================================================================

def membership_inference_attack(model, X_train_splits, y_train, X_test_splits, y_test,
                                forget_indices, retain_indices):
    """
    Simple Membership Inference Attack (MIA).

    Uses prediction confidence as a signal: members tend to have higher
    confidence on the correct class. We compute the attack success rate (ASR)
    on forget set members vs non-members.

    An ASR close to 50% (random guessing) after unlearning → better privacy
    (the model no longer distinguishes members from non-members).

    Args:
        model: VFLFramework
        X_train_splits: training feature splits
        y_train: training labels
        X_test_splits: test feature splits
        y_test: test labels
        forget_indices: indices of forget set in training data
        retain_indices: indices of retain set in training data

    Returns:
        mia_asr: attack success rate (%)
    """
    model.set_eval()

    # Member (forget set) confidences
    forget_splits = [xs[forget_indices] for xs in X_train_splits]
    forget_labels = y_train[forget_indices]
    member_probs = model.predict_proba(forget_splits)
    member_conf = member_probs[torch.arange(len(forget_labels)), forget_labels].numpy()

    # Non-member (test set, same class) confidences
    forget_class = forget_labels[0].item()
    test_class_mask = y_test == forget_class
    if test_class_mask.sum() == 0:
        return 50.0  # Cannot evaluate

    test_class_splits = [xs[test_class_mask] for xs in X_test_splits]
    test_class_labels = y_test[test_class_mask]
    nonmember_probs = model.predict_proba(test_class_splits)
    nonmember_conf = nonmember_probs[torch.arange(len(test_class_labels)), test_class_labels].numpy()

    # Threshold-based attack: predict member if confidence > threshold.
    # Use the median of the combined confidences as the threshold.
    all_conf = np.concatenate([member_conf, nonmember_conf])
    threshold = np.median(all_conf)

    member_pred = (member_conf > threshold).astype(float)
    nonmember_pred = (nonmember_conf <= threshold).astype(float)

    # ASR = average of TPR (correctly predicting members) and TNR (correctly predicting non-members)
    tpr = member_pred.mean()
    tnr = nonmember_pred.mean()
    mia_asr = (tpr + tnr) / 2 * 100

    model.set_train()
    return mia_asr
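

# --- Illustrative sketch (editor addition): the median-threshold attack above
# on synthetic confidences. Members cluster near 0.9, non-members near 0.6, so
# the median split separates them and ASR = (TPR + TNR) / 2 lands far above the
# 50% random-guess floor. All numbers are made up for illustration.
def _demo_threshold_mia():
    rng = np.random.default_rng(0)
    member_conf = rng.normal(0.9, 0.02, size=100)
    nonmember_conf = rng.normal(0.6, 0.02, size=100)
    threshold = np.median(np.concatenate([member_conf, nonmember_conf]))
    tpr = (member_conf > threshold).mean()
    tnr = (nonmember_conf <= threshold).mean()
    asr = (tpr + tnr) / 2 * 100
    assert asr > 90.0  # well-separated confidences are easy to attack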


def compute_feature_sensitivity(model, X_splits, sigma=SENSITIVITY_SIGMA,
                                n_samples=SENSITIVITY_SAMPLES):
    """
    Compute Lipschitz-based feature sensitivity via Monte Carlo perturbation.

    Measures how much the model's output changes when input features are
    perturbed by Gaussian noise. Lower sensitivity after unlearning means
    the model is less responsive to the target features.

    Based on Ferrari (arxiv:2405.17462) Section 4.

    Args:
        model: VFLFramework
        X_splits: feature splits to perturb
        sigma: std of Gaussian perturbation
        n_samples: number of MC samples

    Returns:
        mean_sensitivity: average sensitivity across samples and parties
    """
    model.set_eval()
    sensitivities = []

    # Sample a subset for efficiency
    n = min(500, X_splits[0].shape[0])
    subset_splits = [xs[:n] for xs in X_splits]

    with torch.no_grad():
        # Original output
        logits_orig, _ = model.forward(subset_splits)
        probs_orig = F.softmax(logits_orig, dim=1)

        for _ in range(n_samples):
            for party_idx in range(len(subset_splits)):
                perturbed_splits = [xs.clone() for xs in subset_splits]
                noise = torch.randn_like(perturbed_splits[party_idx]) * sigma
                perturbed_splits[party_idx] = perturbed_splits[party_idx] + noise

                logits_pert, _ = model.forward(perturbed_splits)
                probs_pert = F.softmax(logits_pert, dim=1)

                # L2 distance in probability space
                diff = (probs_orig - probs_pert).norm(dim=1).mean().item()
                sensitivities.append(diff)

    model.set_train()
    return np.mean(sensitivities) if sensitivities else 0.0
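

# --- Illustrative sketch (editor addition): the same MC recipe applied to a
# plain linear map, where sensitivity scales with the weight norm. A model
# that has "unlearned" a feature should show a smaller output change under
# the same perturbation. All values here are hypothetical.
def _demo_mc_sensitivity():
    torch.manual_seed(0)
    x = torch.randn(256, 8)
    W_big, W_small = 5.0 * torch.eye(8), 0.1 * torch.eye(8)

    def mc_sens(W, sigma=0.01, n=20):
        diffs = [((x + sigma * torch.randn_like(x)) @ W - x @ W).norm(dim=1).mean()
                 for _ in range(n)]
        return torch.stack(diffs).mean().item()

    assert mc_sens(W_big) > mc_sens(W_small)  # larger weights → more sensitive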


def full_evaluation(model, X_train_splits, y_train, X_test_splits, y_test,
                    forget_indices, retain_indices, forget_class=0):
    """
    Run full evaluation suite: test accuracy, forget accuracy, retain accuracy,
    MIA ASR, and feature sensitivity.
    """
    # Test accuracy
    test_acc = model.evaluate(X_test_splits, y_test)

    # Forget set accuracy (should be LOW after good unlearning)
    forget_splits = [xs[forget_indices] for xs in X_train_splits]
    forget_labels = y_train[forget_indices]
    forget_acc = model.evaluate(forget_splits, forget_labels)

    # Retain set accuracy (should stay HIGH)
    retain_splits = [xs[retain_indices] for xs in X_train_splits]
    retain_labels = y_train[retain_indices]
    retain_acc = model.evaluate(retain_splits, retain_labels)

    # MIA attack success rate (should be close to 50%, i.e. random guessing)
    mia_asr = membership_inference_attack(
        model, X_train_splits, y_train, X_test_splits, y_test,
        forget_indices, retain_indices
    )

    # Feature sensitivity
    feat_sens = compute_feature_sensitivity(model, forget_splits)

    return {
        "test_acc": round(test_acc, 2),
        "forget_acc": round(forget_acc, 2),
        "retain_acc": round(retain_acc, 2),
        "mia_asr": round(mia_asr, 1),
        "feature_sensitivity": round(feat_sens, 3)
    }


# ============================================================================
# Baseline Unlearning Methods
# ============================================================================

class GradientAscentUnlearning:
    """
    Baseline 1: Gradient Ascent

    Maximizes the loss on the forget set to push the model away from
    correctly classifying forgotten samples. Simple but can cause
    catastrophic degradation of retain set performance.

    Reference: Graves et al. (2020), Thudi et al. (2022)
    """

    def __init__(self, epochs=5, lr=0.01):
        self.epochs = epochs
        self.lr = lr

    def unlearn(self, model, X_train_splits, y_train, forget_indices, retain_indices):
        unlearned = model.clone()
        forget_splits = [xs[forget_indices] for xs in X_train_splits]
        forget_labels = y_train[forget_indices]

        dataset = TensorDataset(*forget_splits, forget_labels)
        loader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True)

        # Use separate optimizer with potentially different LR
        all_params = []
        for pm in unlearned.passive_models:
            all_params += list(pm.parameters())
        all_params += list(unlearned.active_model.parameters())
        optimizer = optim.SGD(all_params, lr=self.lr)

        unlearned.set_train()
        for epoch in range(self.epochs):
            for batch in loader:
                *batch_splits, batch_y = batch
                batch_y = batch_y.to(DEVICE)

                logits, _ = unlearned.forward(batch_splits)
                loss = unlearned.criterion(logits, batch_y)

                optimizer.zero_grad()
                # ASCENT: negate gradient
                (-loss).backward()
                optimizer.step()

        return unlearned


class FineTuneUnlearning:
    """
    Baseline 2: Fine-tuning on Retain Set

    Simply fine-tunes the model on only the retain set, hoping the model
    will "forget" the unlearned data. Often insufficient as the model
    retains significant information about the forget set.

    Reference: Standard baseline in unlearning literature
    """

    def __init__(self, epochs=10, lr=0.001):
        self.epochs = epochs
        self.lr = lr

    def unlearn(self, model, X_train_splits, y_train, forget_indices, retain_indices):
        unlearned = model.clone()
        retain_splits = [xs[retain_indices] for xs in X_train_splits]
        retain_labels = y_train[retain_indices]

        dataset = TensorDataset(*retain_splits, retain_labels)
        loader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True)

        all_params = []
        for pm in unlearned.passive_models:
            all_params += list(pm.parameters())
        all_params += list(unlearned.active_model.parameters())
        optimizer = optim.Adam(all_params, lr=self.lr)

        unlearned.set_train()
        for epoch in range(self.epochs):
            for batch in loader:
                *batch_splits, batch_y = batch
                batch_y = batch_y.to(DEVICE)

                logits, _ = unlearned.forward(batch_splits)
                loss = unlearned.criterion(logits, batch_y)

                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

        return unlearned


class FisherForgetting:
    """
    Baseline 3: Fisher Forgetting

    Uses the diagonal of the Fisher Information Matrix to identify which
    parameters are most important for the forget set, then adds Gaussian
    noise whose std is proportional to the square root of those Fisher
    entries, so the parameters most informative about the forget set
    receive the most noise. This selectively "erases" information about
    the forget set.

    Reference: Golatkar et al. (2020) "Eternal Sunshine of the Spotless Net"
    """

    def __init__(self, noise_scale=0.01):
        self.noise_scale = noise_scale

    def unlearn(self, model, X_train_splits, y_train, forget_indices, retain_indices):
        unlearned = model.clone()

        forget_splits = [xs[forget_indices] for xs in X_train_splits]
        forget_labels = y_train[forget_indices]

        # Compute Fisher diagonal on forget set
        unlearned.set_train()
        fisher_diag = {}
        for name, param in self._get_all_params(unlearned):
            fisher_diag[name] = torch.zeros_like(param.data)

        dataset = TensorDataset(*forget_splits, forget_labels)
        loader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=False)

        for batch in loader:
            *batch_splits, batch_y = batch
            batch_y = batch_y.to(DEVICE)

            logits, _ = unlearned.forward(batch_splits)
            loss = unlearned.criterion(logits, batch_y)

            unlearned.optimizer.zero_grad()
            loss.backward()

            for name, param in self._get_all_params(unlearned):
                if param.grad is not None:
                    fisher_diag[name] += param.grad.data ** 2

        # Normalize
        n_batches = len(loader)
        for name in fisher_diag:
            fisher_diag[name] /= max(n_batches, 1)

        # Add noise proportional to Fisher
        with torch.no_grad():
            for name, param in self._get_all_params(unlearned):
                noise_std = self.noise_scale * (fisher_diag[name] + 1e-8).sqrt()
                param.data += torch.randn_like(param.data) * noise_std

        return unlearned

    def _get_all_params(self, model):
        """Get all named parameters from VFL framework."""
        params = []
        for i, pm in enumerate(model.passive_models):
            for name, param in pm.named_parameters():
                params.append((f"passive_{i}.{name}", param))
        for name, param in model.active_model.named_parameters():
            params.append((f"active.{name}", param))
        return params
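

# --- Illustrative sketch (editor addition): the Fisher-diagonal estimate used
# above, reduced to a single linear layer. The diagonal is the per-parameter
# squared gradient of the loss on the target batch; larger entries mark
# parameters that carry more information about those samples.
def _demo_fisher_diag():
    torch.manual_seed(0)
    lin = nn.Linear(4, 3)
    x, y = torch.randn(32, 4), torch.randint(0, 3, (32,))
    lin.zero_grad()
    F.cross_entropy(lin(x), y).backward()
    fisher = {n: p.grad.detach() ** 2 for n, p in lin.named_parameters()}
    assert set(fisher) == {"weight", "bias"}
    assert all(v.ge(0).all() for v in fisher.values())  # squared → non-negative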


class ManifoldMixupUnlearning:
    """
    Baseline 4: Manifold Mixup (Paper 1 - arxiv:2410.10922)

    Performs manifold mixup in the embedding space between forget set samples
    and random noise/other class samples, combined with gradient ascent.
    This disrupts the learned representations for the forget set.

    Adapted from: Bryan et al. (2024) "Towards Privacy-Guaranteed Label
    Unlearning in Vertical Federated Learning"
    """

    def __init__(self, epochs=10, lr=0.005, mixup_alpha=0.3):
        self.epochs = epochs
        self.lr = lr
        self.mixup_alpha = mixup_alpha

    def unlearn(self, model, X_train_splits, y_train, forget_indices, retain_indices):
        unlearned = model.clone()

        forget_splits = [xs[forget_indices] for xs in X_train_splits]
        forget_labels = y_train[forget_indices]
        retain_splits = [xs[retain_indices] for xs in X_train_splits]
        retain_labels = y_train[retain_indices]

        all_params = []
        for pm in unlearned.passive_models:
            all_params += list(pm.parameters())
        all_params += list(unlearned.active_model.parameters())
        optimizer = optim.Adam(all_params, lr=self.lr)

        unlearned.set_train()
        for epoch in range(self.epochs):
            # Step 1: Manifold mixup on forget set embeddings
            forget_emb = unlearned.get_embeddings(forget_splits)
            # Mix with random noise (simulates "corrupting" forget representations)
            noise = torch.randn_like(forget_emb)
            lam = np.random.beta(self.mixup_alpha, self.mixup_alpha)
            mixed_emb = lam * forget_emb + (1 - lam) * noise

            # Gradient ascent on mixed embeddings
            logits_mixed = unlearned.active_model(mixed_emb)
            loss_forget = unlearned.criterion(logits_mixed, forget_labels.to(DEVICE))

            # Step 2: Recovery on retain set
            n_retain_batch = min(BATCH_SIZE, len(retain_labels))
            idx = torch.randperm(len(retain_labels))[:n_retain_batch]
            retain_batch = [xs[idx] for xs in retain_splits]
            retain_batch_y = retain_labels[idx].to(DEVICE)

            logits_retain, _ = unlearned.forward(retain_batch)
            loss_retain = unlearned.criterion(logits_retain, retain_batch_y)

            # Combined: ascend on forget, descend on retain
            loss = loss_retain - 0.5 * loss_forget

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        return unlearned
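

# --- Illustrative sketch (editor addition): the Beta-sampled mixup step used
# above, in isolation. lam ~ Beta(α, α) with small α concentrates near 0 or 1,
# so the mixed embedding is usually dominated by either the original or the
# noise. Toy tensors with hypothetical values.
def _demo_manifold_mixup_step():
    np.random.seed(0)
    emb = torch.ones(4, 8)    # stand-in for forget-set embeddings
    noise = torch.zeros(4, 8)  # stand-in for the corruption target
    lam = np.random.beta(0.3, 0.3)
    mixed = lam * emb + (1 - lam) * noise
    assert torch.allclose(mixed, torch.full((4, 8), float(lam)))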


class FerrariUnlearning:
    """
    Baseline 5: Ferrari (Paper 2 - arxiv:2405.17462)

    Minimizes feature sensitivity to target features via Lipschitz-based
    optimization. Uses Monte Carlo perturbation to estimate sensitivity
    and optimizes to reduce it.

    Adapted from: Ong et al. (2024) "Ferrari: Federated Feature Unlearning
    via Optimizing Feature Sensitivity"

    Note: Original Ferrari is for HFL. We adapt it to VFL by applying
    sensitivity minimization to the passive party that holds the target features.
    """

    def __init__(self, epochs=15, lr=0.005, sigma=0.01, n_samples=5):
        self.epochs = epochs
        self.lr = lr
        self.sigma = sigma
        self.n_samples = n_samples

    def unlearn(self, model, X_train_splits, y_train, forget_indices, retain_indices):
        unlearned = model.clone()

        forget_splits = [xs[forget_indices] for xs in X_train_splits]
        forget_labels = y_train[forget_indices]
        retain_splits = [xs[retain_indices] for xs in X_train_splits]
        retain_labels = y_train[retain_indices]

        all_params = []
        for pm in unlearned.passive_models:
            all_params += list(pm.parameters())
        all_params += list(unlearned.active_model.parameters())
        optimizer = optim.Adam(all_params, lr=self.lr)

        unlearned.set_train()
        for epoch in range(self.epochs):
            # Sensitivity minimization on forget set
            sensitivity_loss = torch.tensor(0.0, device=DEVICE)

            logits_orig, _ = unlearned.forward(forget_splits)
            probs_orig = F.softmax(logits_orig, dim=1)

            for _ in range(self.n_samples):
                for party_idx in range(len(forget_splits)):
                    perturbed = [xs.clone() for xs in forget_splits]
                    noise = torch.randn_like(perturbed[party_idx]) * self.sigma
                    perturbed[party_idx] = perturbed[party_idx] + noise

                    logits_pert, _ = unlearned.forward(perturbed)
                    probs_pert = F.softmax(logits_pert, dim=1)

                    # Sensitivity = expected output change per unit perturbation
                    diff = (probs_orig - probs_pert).norm(dim=1).mean()
                    sensitivity_loss = sensitivity_loss + diff

            sensitivity_loss = sensitivity_loss / (self.n_samples * len(forget_splits))

            # Retain utility
            n_retain_batch = min(BATCH_SIZE, len(retain_labels))
            idx = torch.randperm(len(retain_labels))[:n_retain_batch]
            retain_batch = [xs[idx] for xs in retain_splits]
            retain_batch_y = retain_labels[idx].to(DEVICE)

            logits_retain, _ = unlearned.forward(retain_batch)
            loss_retain = unlearned.criterion(logits_retain, retain_batch_y)

            # Combined: minimize sensitivity + maintain retain performance
            loss = loss_retain + 2.0 * sensitivity_loss

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        return unlearned


# ============================================================================
# UFUSC: Unified Federated Unlearning via Sensitivity-Guided Contrastive Forgetting
# ============================================================================

class UFUSC:
    """
    UFUSC: Unified Federated Unlearning via Sensitivity-Guided Contrastive Forgetting

    The FIRST framework to simultaneously handle BOTH label AND feature unlearning
    in Vertical Federated Learning.

    Three components:
    1. Contrastive Forgetting Loss (CFL) — Pushes forget-set embeddings toward
       random noise while anchoring retain-set embeddings to class centroids.
       Operates in the joint embedding space for "deep forgetting" (not just
       output-level like gradient ascent).

    2. Lipschitz Feature Sensitivity Minimization — Monte Carlo perturbation-based
       sensitivity estimation, extended to VFL. Minimizes the model's responsiveness
       to features associated with the forget set.

    3. Dual-Variable Certification — Primal-dual formulation that provides a
       convergence-based forgetting guarantee. A hinge penalty weighted by Ω
       applies forgetting pressure only while the forget-set loss is below the
       threshold τ, i.e. while the model still remembers too much.

    Loss function:
        L = L_retain + α·L_CFL + β·L_sensitivity + γ·L_anchor + Ω·max(0, τ - L_forget_CE)

    Variants:
    - Label Only: Uses CFL + anchor (no sensitivity)
    - Feature Only: Uses sensitivity + CFL (no anchor)
    - Joint: All three components (full UFUSC)
    """

    def __init__(self, mode="joint", alpha=ALPHA, beta=BETA, gamma=GAMMA,
                 omega=OMEGA, tau=TAU, epochs=UNLEARN_EPOCHS, lr=0.005,
                 sigma=SENSITIVITY_SIGMA, n_mc_samples=SENSITIVITY_SAMPLES):
        """
        Args:
            mode: "label_only", "feature_only", or "joint"
            alpha: weight for Contrastive Forgetting Loss
            beta: weight for Feature Sensitivity Loss
            gamma: weight for Anchor Loss (retain embedding stability)
            omega: weight for dual-variable certification constraint
            tau: forgetting threshold for certification
            epochs: number of unlearning epochs
            lr: learning rate for unlearning
            sigma: std for MC perturbation (feature sensitivity)
            n_mc_samples: number of MC samples for sensitivity
        """
        assert mode in ["label_only", "feature_only", "joint"]
        self.mode = mode
        self.alpha = alpha
        self.beta = beta
        self.gamma = gamma
        self.omega = omega
        self.tau = tau
        self.epochs = epochs
        self.lr = lr
        self.sigma = sigma
        self.n_mc_samples = n_mc_samples

    def compute_class_centroids(self, model, X_splits, y, num_classes):
        """
        Compute class centroids in the joint embedding space.

        These serve as "anchor points" — retain-set embeddings should
        stay close to their class centroid during unlearning.
        """
        model.set_eval()
        with torch.no_grad():
            embeddings = model.get_embeddings(X_splits)

        centroids = {}
        for c in range(num_classes):
            mask = (y == c)
            if mask.sum() > 0:
                centroids[c] = embeddings[mask].mean(dim=0).detach()
            else:
                centroids[c] = torch.zeros(embeddings.shape[1], device=DEVICE)

        model.set_train()
        return centroids

    def contrastive_forgetting_loss(self, model, forget_splits, forget_labels,
                                    centroids, num_classes):
        """
        Contrastive Forgetting Loss (CFL).

        Pushes forget-set embeddings AWAY from their true class centroids
        and TOWARD random noise. This disrupts the learned representations
        at the embedding level, achieving "deep forgetting."

            L_CFL = -E[||e_forget - c_true||] + 0.5 · E[||e_forget - ε||],  ε ~ N(0, I)

        The first term pushes embeddings away from the correct centroid.
        The second term pulls embeddings toward meaningless random noise.
        """
        forget_emb = model.get_embeddings(forget_splits)

        # Repulsion from true class centroids
        repulsion_loss = torch.tensor(0.0, device=DEVICE)
        for i in range(len(forget_labels)):
            c = forget_labels[i].item()
            if c in centroids:
                dist = (forget_emb[i] - centroids[c]).norm()
                repulsion_loss = repulsion_loss - dist  # Maximize distance
        repulsion_loss = repulsion_loss / max(len(forget_labels), 1)

        # Attraction toward noise (make embeddings meaningless)
        noise_target = torch.randn_like(forget_emb)
        attraction_loss = (forget_emb - noise_target).norm(dim=1).mean()

        return repulsion_loss + 0.5 * attraction_loss
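
    @staticmethod
    def _demo_cfl_direction():
        # Illustrative numeric check (editor addition, hypothetical values):
        # the repulsion term above decreases as a forget embedding moves away
        # from its class centroid, which is exactly the pressure CFL applies.
        # Toy 2-D embeddings; the noise-attraction term is omitted here.
        c = torch.zeros(2)  # class centroid at the origin
        near, far = torch.tensor([0.1, 0.0]), torch.tensor([5.0, 0.0])
        assert -(far - c).norm() < -(near - c).norm()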

    def feature_sensitivity_loss(self, model, forget_splits):
        """
        Lipschitz Feature Sensitivity Loss.

        Measures and minimizes the model's sensitivity to features in the
        forget set via Monte Carlo perturbation. Extended from Ferrari to VFL.

        For each passive party's features:
            S = E[||f(x) - f(x + δ)||],  δ ~ N(0, σ²I)
        (σ is fixed, so the 1/||δ|| normalization of the usual Lipschitz
        ratio is a constant and is absorbed into the weight β.)

        We minimize S to make the model "insensitive" to forget-set features.
        """
        sensitivity = torch.tensor(0.0, device=DEVICE)

        logits_orig, _ = model.forward(forget_splits)
        probs_orig = F.softmax(logits_orig, dim=1)

        for _ in range(self.n_mc_samples):
            for party_idx in range(len(forget_splits)):
                perturbed = [xs.clone() for xs in forget_splits]
                noise = torch.randn_like(perturbed[party_idx]) * self.sigma
                perturbed[party_idx] = perturbed[party_idx] + noise

                logits_pert, _ = model.forward(perturbed)
                probs_pert = F.softmax(logits_pert, dim=1)

                diff = (probs_orig - probs_pert).norm(dim=1).mean()
                sensitivity = sensitivity + diff

        sensitivity = sensitivity / (self.n_mc_samples * len(forget_splits))
        return sensitivity

    def anchor_loss(self, model, retain_splits, retain_labels, centroids):
        """
        Anchor Loss.

        Ensures retain-set embeddings stay close to their class centroids
        during unlearning. This prevents "catastrophic forgetting" of
        the retain set while aggressively unlearning the forget set.

            L_anchor = E[||e_retain - c_class||^2]
        """
        retain_emb = model.get_embeddings(retain_splits)

        loss = torch.tensor(0.0, device=DEVICE)
        for i in range(len(retain_labels)):
            c = retain_labels[i].item()
            if c in centroids:
                loss = loss + (retain_emb[i] - centroids[c]).norm() ** 2

        return loss / max(len(retain_labels), 1)

    def dual_variable_certification(self, model, forget_splits, forget_labels):
        """
        Dual-Variable Certification.

        Primal-dual formulation that provides a convergence-based forgetting
        guarantee. The constraint is:

            L_forget_CE ≥ τ   (cross-entropy on forget set should be HIGH)

        enforced via the hinge penalty:

            Ω · max(0, τ - L_forget_CE)

        When the forget CE is below τ, this adds pressure to increase it.
        When it is above τ, the term vanishes (constraint satisfied). Here Ω
        is a fixed penalty weight; a full primal-dual scheme would instead
        update the dual variable λ across epochs.

        Inspired by FedORA (arxiv:2512.23171).
        """
        logits, _ = model.forward(forget_splits)
        forget_ce = model.criterion(logits, forget_labels.to(DEVICE))

        # Penalty when forget CE is below threshold
        violation = F.relu(self.tau - forget_ce)
        return self.omega * violation
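
    @staticmethod
    def _demo_certification_hinge():
        # Illustrative numeric check (editor addition): the hinge penalty
        # above is active only while the forget-set cross-entropy sits below
        # the threshold τ, and vanishes once the constraint is satisfied.
        tau, omega = 2.0, 0.1
        below = omega * F.relu(torch.tensor(tau - 0.5))  # CE = 0.5 < τ
        above = omega * F.relu(torch.tensor(tau - 3.0))  # CE = 3.0 ≥ τ
        assert below.item() > 0.0 and above.item() == 0.0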

    def unlearn(self, model, X_train_splits, y_train, forget_indices, retain_indices,
                num_classes=10):
        """
        Execute UFUSC unlearning.

        Args:
            model: trained VFLFramework
            X_train_splits: list of K feature tensors
            y_train: training labels
            forget_indices: indices of forget set
            retain_indices: indices of retain set
            num_classes: number of classes

        Returns:
            unlearned VFLFramework
        """
        unlearned = model.clone()

        forget_splits = [xs[forget_indices] for xs in X_train_splits]
        forget_labels = y_train[forget_indices]
        retain_splits = [xs[retain_indices] for xs in X_train_splits]
        retain_labels = y_train[retain_indices]

        # Compute class centroids before unlearning
        centroids = self.compute_class_centroids(
            unlearned, [xs[retain_indices] for xs in X_train_splits],
            retain_labels, num_classes
        )

        all_params = []
        for pm in unlearned.passive_models:
            all_params += list(pm.parameters())
        all_params += list(unlearned.active_model.parameters())
        optimizer = optim.Adam(all_params, lr=self.lr)

        unlearned.set_train()
        for epoch in range(self.epochs):
            total_loss = torch.tensor(0.0, device=DEVICE)

            # 1. Retain set CE loss (always active)
            n_retain_batch = min(BATCH_SIZE, len(retain_labels))
            idx = torch.randperm(len(retain_labels))[:n_retain_batch]
            retain_batch = [xs[idx] for xs in retain_splits]
            retain_batch_y = retain_labels[idx].to(DEVICE)

            logits_retain, _ = unlearned.forward(retain_batch)
            loss_retain = unlearned.criterion(logits_retain, retain_batch_y)
            total_loss = total_loss + loss_retain

            # 2. Contrastive Forgetting Loss (CFL)
            if self.mode in ["label_only", "joint"]:
                cfl = self.contrastive_forgetting_loss(
                    unlearned, forget_splits, forget_labels, centroids, num_classes
                )
                total_loss = total_loss + self.alpha * cfl

            if self.mode in ["feature_only", "joint"]:
                cfl_feat = self.contrastive_forgetting_loss(
                    unlearned, forget_splits, forget_labels, centroids, num_classes
                )
                total_loss = total_loss + self.alpha * 0.5 * cfl_feat

            # 3. Feature Sensitivity Loss
            if self.mode in ["feature_only", "joint"]:
                sens = self.feature_sensitivity_loss(unlearned, forget_splits)
                total_loss = total_loss + self.beta * sens

            # 4. Anchor Loss
            if self.mode in ["label_only", "joint"]:
                anc = self.anchor_loss(
                    unlearned, retain_batch, retain_batch_y, centroids
                )
                total_loss = total_loss + self.gamma * anc

            # 5. Dual-Variable Certification
            cert = self.dual_variable_certification(
                unlearned, forget_splits, forget_labels
            )
            total_loss = total_loss + cert

            optimizer.zero_grad()
            total_loss.backward()
            # Gradient clipping for stability
            torch.nn.utils.clip_grad_norm_(all_params, max_norm=5.0)
            optimizer.step()

        return unlearned
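

# --- Illustrative usage sketch (editor addition, not part of the original
# commit): running UFUSC end to end on a tiny random problem. All sizes are
# hypothetical and chosen only so the code executes quickly; real experiments
# go through run_single_experiment() below.
def _demo_ufusc_unlearn():
    set_seed()
    fw = VFLFramework(feature_dims=[6, 4], num_classes=3, embed_dim=8, num_parties=2)
    X_splits = [torch.randn(64, 6), torch.randn(64, 4)]
    y = torch.tensor([i % 3 for i in range(64)])  # balanced toy labels
    f_idx, r_idx = create_forget_retain_split(y, forget_class=0, forget_ratio=0.5)
    unlearned = UFUSC(mode="joint", epochs=1).unlearn(
        fw, X_splits, y, f_idx, r_idx, num_classes=3
    )
    assert unlearned is not fw  # unlearning operates on a clone of the model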


# ============================================================================
# Experiment Runner
# ============================================================================

def run_single_experiment(dataset_name, num_parties=NUM_PASSIVE_PARTIES, verbose=True):
    """
    Run complete experiment for one dataset.

    Steps:
    1. Load dataset
    2. Split features across K passive parties (VFL)
    3. Train VFL model
    4. Create forget/retain split
    5. Evaluate original model
    6. Run all 5 baselines
    7. Run 3 UFUSC variants
    8. Return all results

    Args:
        dataset_name: "MNIST", "Fashion-MNIST", or "CIFAR-10"
        num_parties: number of passive parties
        verbose: print progress

    Returns:
        list of result dicts
    """
    set_seed()
    print(f"\n{'='*70}")
    print(f"  EXPERIMENT: {dataset_name} (K={num_parties} parties)")
    print(f"{'='*70}")

    # 1. Load dataset
    print("\n[1/8] Loading dataset...")
    X_train, y_train, X_test, y_test, num_classes, feature_dim = load_dataset(dataset_name)

    # 2. Split features for VFL
    print("[2/8] Splitting features for VFL...")
    X_train_splits = list(split_features_vfl(X_train, num_parties))
    X_test_splits = list(split_features_vfl(X_test, num_parties))
    feature_dims = [xs.shape[1] for xs in X_train_splits]
    print(f"  Party feature dims: {feature_dims}")

    # 3. Train VFL model
    print("[3/8] Training VFL model...")
    model = VFLFramework(feature_dims, num_classes, num_parties=num_parties)
    model.train_model(X_train_splits, y_train, X_test_splits, y_test, epochs=TRAIN_EPOCHS)

    # 4. Create forget/retain split
    print("[4/8] Creating forget/retain split...")
    forget_class = 0
    forget_indices, retain_indices = create_forget_retain_split(
        y_train, forget_class=forget_class, forget_ratio=FORGET_RATIO
    )
    print(f"  Forget set: {len(forget_indices)} samples (class {forget_class})")
    print(f"  Retain set: {len(retain_indices)} samples")

    # 5. Evaluate original model
    print("[5/8] Evaluating original model...")
    original_metrics = full_evaluation(
        model, X_train_splits, y_train, X_test_splits, y_test,
        forget_indices, retain_indices, forget_class
    )
    original_metrics["method"] = "Original (No Unlearn)"
    original_metrics["time_seconds"] = 0
    print(f"  Original: {original_metrics}")

    results = [original_metrics]

    # 6. Run baselines
    baselines = [
        ("Gradient Ascent", GradientAscentUnlearning(epochs=5, lr=0.01)),
        ("Fine-tuning", FineTuneUnlearning(epochs=10, lr=0.001)),
        ("Fisher Forgetting", FisherForgetting(noise_scale=0.01)),
        ("Manifold Mixup (P1)", ManifoldMixupUnlearning(epochs=10, lr=0.005)),
        ("Ferrari (P2)", FerrariUnlearning(epochs=15, lr=0.005)),
    ]

    print("[6/8] Running baselines...")
    for name, method in baselines:
        print(f"  Running {name}...")
        t0 = time.time()
        unlearned = method.unlearn(model, X_train_splits, y_train, forget_indices, retain_indices)
        elapsed = time.time() - t0

        metrics = full_evaluation(
            unlearned, X_train_splits, y_train, X_test_splits, y_test,
            forget_indices, retain_indices, forget_class
        )
        metrics["method"] = name
        metrics["time_seconds"] = round(elapsed, 2)
        results.append(metrics)
        print(f"    {name}: Forget={metrics['forget_acc']:.1f}%, "
              f"Retain={metrics['retain_acc']:.1f}%, MIA={metrics['mia_asr']:.1f}%")

    # 7. Run UFUSC variants
    print("[7/8] Running UFUSC variants...")
    ufusc_variants = [
        ("UFUSC (Label Only)", UFUSC(mode="label_only", epochs=UNLEARN_EPOCHS)),
        ("UFUSC (Feature Only)", UFUSC(mode="feature_only", epochs=UNLEARN_EPOCHS)),
        ("UFUSC (Joint)", UFUSC(mode="joint", epochs=UNLEARN_EPOCHS)),
    ]

    for name, method in ufusc_variants:
        print(f"  Running {name}...")
        t0 = time.time()
        unlearned = method.unlearn(
            model, X_train_splits, y_train, forget_indices, retain_indices,
            num_classes=num_classes
        )
        elapsed = time.time() - t0

        metrics = full_evaluation(
            unlearned, X_train_splits, y_train, X_test_splits, y_test,
            forget_indices, retain_indices, forget_class
        )
        metrics["method"] = name
        metrics["time_seconds"] = round(elapsed, 2)
        results.append(metrics)
        print(f"    {name}: Forget={metrics['forget_acc']:.1f}%, "
              f"Retain={metrics['retain_acc']:.1f}%, MIA={metrics['mia_asr']:.1f}%")

    # 8. Summary
    print(f"\n[8/8] {dataset_name} Summary:")
    print(f"  {'Method':<25} {'Test':>8} {'Forget':>8} {'Retain':>8} {'MIA':>8} {'Sens':>8}")
    print(f"  {'-'*73}")
    for r in results:
        print(f"  {r['method']:<25} {r['test_acc']:>7.2f}% {r['forget_acc']:>7.2f}% "
              f"{r['retain_acc']:>7.2f}% {r['mia_asr']:>7.1f}% {r['feature_sensitivity']:>7.3f}")

    return results


# ============================================================================
# Ablation Study
# ============================================================================

def run_ablation_study(dataset_name="MNIST"):
    """
    Ablation study on UFUSC hyperparameters: α, β, γ, and unlearning epochs.

    Tests the impact of each component by varying one hyperparameter
    while keeping others at their default values.

    Returns:
        list of ablation result dicts
    """
    set_seed()
    print(f"\n{'='*70}")
    print(f"  ABLATION STUDY: {dataset_name}")
    print(f"{'='*70}")

    # Load and prepare
    X_train, y_train, X_test, y_test, num_classes, feature_dim = load_dataset(dataset_name)
    X_train_splits = list(split_features_vfl(X_train))
    X_test_splits = list(split_features_vfl(X_test))
    feature_dims = [xs.shape[1] for xs in X_train_splits]

    model = VFLFramework(feature_dims, num_classes)
    model.train_model(X_train_splits, y_train, X_test_splits, y_test, epochs=TRAIN_EPOCHS, verbose=False)

    forget_indices, retain_indices = create_forget_retain_split(y_train)

    ablation_results = []

    # Ablation 1: Vary α (CFL weight)
    print("\n  Ablation: α (CFL weight)")
    for alpha_val in [0.0, 0.5, 1.0, 2.0, 5.0]:
        method = UFUSC(mode="joint", alpha=alpha_val, beta=BETA, gamma=GAMMA, epochs=UNLEARN_EPOCHS)
        unlearned = method.unlearn(model, X_train_splits, y_train, forget_indices, retain_indices, num_classes)
        metrics = full_evaluation(unlearned, X_train_splits, y_train, X_test_splits, y_test,
                                  forget_indices, retain_indices)
        metrics["ablation_param"] = "alpha"
        metrics["ablation_value"] = alpha_val
        ablation_results.append(metrics)
        print(f"    α={alpha_val}: Forget={metrics['forget_acc']:.1f}%, Retain={metrics['retain_acc']:.1f}%")

    # Ablation 2: Vary β (Sensitivity weight)
    print("\n  Ablation: β (Sensitivity weight)")
    for beta_val in [0.0, 0.25, 0.5, 1.0, 2.0]:
        method = UFUSC(mode="joint", alpha=ALPHA, beta=beta_val, gamma=GAMMA, epochs=UNLEARN_EPOCHS)
        unlearned = method.unlearn(model, X_train_splits, y_train, forget_indices, retain_indices, num_classes)
        metrics = full_evaluation(unlearned, X_train_splits, y_train, X_test_splits, y_test,
                                  forget_indices, retain_indices)
        metrics["ablation_param"] = "beta"
        metrics["ablation_value"] = beta_val
        ablation_results.append(metrics)
        print(f"    β={beta_val}: Forget={metrics['forget_acc']:.1f}%, Retain={metrics['retain_acc']:.1f}%")

    # Ablation 3: Vary γ (Anchor weight)
    print("\n  Ablation: γ (Anchor weight)")
    for gamma_val in [0.0, 0.1, 0.3, 0.5, 1.0]:
        method = UFUSC(mode="joint", alpha=ALPHA, beta=BETA, gamma=gamma_val, epochs=UNLEARN_EPOCHS)
        unlearned = method.unlearn(model, X_train_splits, y_train, forget_indices, retain_indices, num_classes)
        metrics = full_evaluation(unlearned, X_train_splits, y_train, X_test_splits, y_test,
                                  forget_indices, retain_indices)
        metrics["ablation_param"] = "gamma"
        metrics["ablation_value"] = gamma_val
        ablation_results.append(metrics)
        print(f"    γ={gamma_val}: Forget={metrics['forget_acc']:.1f}%, Retain={metrics['retain_acc']:.1f}%")

    # Ablation 4: Vary unlearning epochs
    print("\n  Ablation: Unlearning epochs")
    for ep in [1, 5, 10, 15, 20]:
        method = UFUSC(mode="joint", alpha=ALPHA, beta=BETA, gamma=GAMMA, epochs=ep)
        unlearned = method.unlearn(model, X_train_splits, y_train, forget_indices, retain_indices, num_classes)
        metrics = full_evaluation(unlearned, X_train_splits, y_train, X_test_splits, y_test,
                                  forget_indices, retain_indices)
        metrics["ablation_param"] = "epochs"
        metrics["ablation_value"] = ep
        ablation_results.append(metrics)
        print(f"    epochs={ep}: Forget={metrics['forget_acc']:.1f}%, Retain={metrics['retain_acc']:.1f}%")

    return ablation_results


# ============================================================================
# Scalability Analysis
# ============================================================================

def run_scalability_analysis(dataset_name="MNIST"):
    """
    Scalability analysis: test UFUSC with varying number of passive parties K.

    Tests K = 2, 3, 4, 6 to see how the method scales in VFL settings
    with different numbers of data holders.

    Returns:
        list of scalability result dicts
    """
    set_seed()
    print(f"\n{'='*70}")
    print(f"  SCALABILITY ANALYSIS: {dataset_name}")
    print(f"{'='*70}")

    X_train, y_train, X_test, y_test, num_classes, feature_dim = load_dataset(dataset_name)

    scalability_results = []

    for K in [2, 3, 4, 6]:
        print(f"\n  K={K} parties...")
        X_train_splits = list(split_features_vfl(X_train, K))
        X_test_splits = list(split_features_vfl(X_test, K))
        feature_dims = [xs.shape[1] for xs in X_train_splits]

        model = VFLFramework(feature_dims, num_classes, num_parties=K)
        model.train_model(X_train_splits, y_train, X_test_splits, y_test,
                          epochs=TRAIN_EPOCHS, verbose=False)

        forget_indices, retain_indices = create_forget_retain_split(y_train)

        # Evaluate original
        orig_metrics = full_evaluation(model, X_train_splits, y_train, X_test_splits, y_test,
                                       forget_indices, retain_indices)

        # Run UFUSC-Joint
        ufusc = UFUSC(mode="joint", epochs=UNLEARN_EPOCHS)
        t0 = time.time()
        unlearned = ufusc.unlearn(model, X_train_splits, y_train, forget_indices, retain_indices, num_classes)
        elapsed = time.time() - t0

        ufusc_metrics = full_evaluation(unlearned, X_train_splits, y_train, X_test_splits, y_test,
                                        forget_indices, retain_indices)

        result = {
            "K": K,
            "original_test_acc": orig_metrics["test_acc"],
            "original_forget_acc": orig_metrics["forget_acc"],
            "ufusc_test_acc": ufusc_metrics["test_acc"],
            "ufusc_forget_acc": ufusc_metrics["forget_acc"],
            "ufusc_retain_acc": ufusc_metrics["retain_acc"],
            "ufusc_mia_asr": ufusc_metrics["mia_asr"],
            "time_seconds": round(elapsed, 2)
        }
        scalability_results.append(result)
        print(f"    K={K}: Original Test={orig_metrics['test_acc']:.1f}%, "
              f"UFUSC Forget={ufusc_metrics['forget_acc']:.1f}%, "
              f"Retain={ufusc_metrics['retain_acc']:.1f}%, Time={elapsed:.1f}s")

    return scalability_results
1426
+
1427
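+ # Illustrative sketch (an assumption for reference, not the pipeline's splitter):
+ # `split_features_vfl`, defined earlier in this file, gives each of the K passive
+ # parties a block of feature columns. The simplest such scheme is an even
+ # column-wise split:
+ def split_features_evenly(X, K):
+     """Yield K column-wise slices of X whose widths differ by at most one."""
+     import numpy as np  # local import keeps this sketch self-contained
+     boundaries = np.linspace(0, X.shape[1], K + 1, dtype=int)
+     for lo, hi in zip(boundaries[:-1], boundaries[1:]):
+         yield X[:, lo:hi]
+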
+
+ # ============================================================================
+ # Visualization
+ # ============================================================================
+
+ def create_visualizations(all_results, ablation_results=None, scalability_results=None):
+     """
+     Create all publication-quality figures.
+
+     Generates:
+       - Comparison bar charts (one per dataset)
+       - Radar plots (one per dataset)
+       - Ablation study plot
+       - Scalability analysis plot
+       - Privacy-utility tradeoff plots (one per dataset)
+     """
+     try:
+         import matplotlib
+         matplotlib.use('Agg')  # non-interactive backend: render straight to files, no display needed
+         import matplotlib.pyplot as plt
+         import seaborn as sns
+         sns.set_theme(style="whitegrid")
+     except ImportError:
+         print("WARNING: matplotlib/seaborn not available. Skipping visualization.")
+         return
+
+     # Ensure the output directory exists (exist_ok makes this a no-op
+     # if it was already created earlier in the file)
+     import os
+     os.makedirs("figures", exist_ok=True)
+
+     colors = {
+         "Original (No Unlearn)": "#95a5a6",
+         "Gradient Ascent": "#e74c3c",
+         "Fine-tuning": "#e67e22",
+         "Fisher Forgetting": "#f39c12",
+         "Manifold Mixup (P1)": "#27ae60",
+         "Ferrari (P2)": "#2980b9",
+         "UFUSC (Label Only)": "#8e44ad",
+         "UFUSC (Feature Only)": "#1abc9c",
+         "UFUSC (Joint)": "#c0392b",
+     }
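+
+     # Shape assumed for `all_results` throughout this function (keys as
+     # produced by full_evaluation earlier in the file):
+     #   {dataset_name: [{"method": str, "forget_acc": float, "retain_acc": float,
+     #                    "test_acc": float, "mia_asr": float,
+     #                    "feature_sensitivity": float, ...}, ...]}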
+
+     # ---- Comparison Bar Charts (one per dataset) ----
+     for dataset_name, results in all_results.items():
+         fig, axes = plt.subplots(1, 3, figsize=(18, 6))
+         fig.suptitle(f"{dataset_name} — Unlearning Method Comparison", fontsize=16, fontweight='bold')
+
+         methods = [r["method"] for r in results]
+         method_colors = [colors.get(m, "#333333") for m in methods]
+
+         # Forget accuracy (lower is better)
+         vals = [r["forget_acc"] for r in results]
+         axes[0].barh(methods, vals, color=method_colors)
+         axes[0].set_xlabel("Forget Accuracy (%) ↓")
+         axes[0].set_title("Forgetting Quality")
+         axes[0].invert_yaxis()
+
+         # Retain accuracy (higher is better)
+         vals = [r["retain_acc"] for r in results]
+         axes[1].barh(methods, vals, color=method_colors)
+         axes[1].set_xlabel("Retain Accuracy (%) ↑")
+         axes[1].set_title("Utility Preservation")
+         axes[1].invert_yaxis()
+
+         # MIA attack success rate (lower is better; 50% = random guessing)
+         vals = [r["mia_asr"] for r in results]
+         axes[2].barh(methods, vals, color=method_colors)
+         axes[2].set_xlabel("MIA ASR (%) ↓")
+         axes[2].set_title("Privacy Protection")
+         axes[2].axvline(x=50, color='red', linestyle='--', alpha=0.5, label='Random (50%)')
+         axes[2].invert_yaxis()
+         axes[2].legend()
+
+         plt.tight_layout()
+         plt.savefig(f"figures/{dataset_name.replace('-', '_')}_comparison.png", dpi=150, bbox_inches='tight')
+         plt.close()
+         print(f"  Saved: figures/{dataset_name.replace('-', '_')}_comparison.png")
+
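+     # The three panels above repeat one pattern. For illustration only (an
+     # optional refactor sketch; nothing below calls it), the shared logic is:
+     def _barh_panel(ax, methods, vals, method_colors, xlabel, title):
+         """Draw one horizontal-bar comparison panel (helper sketch, unused)."""
+         ax.barh(methods, vals, color=method_colors)
+         ax.set_xlabel(xlabel)
+         ax.set_title(title)
+         ax.invert_yaxis()
+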
+     # ---- Radar Plots (one per dataset) ----
+     for dataset_name, results in all_results.items():
+         # Select key methods for the radar
+         key_methods = ["Gradient Ascent", "Manifold Mixup (P1)", "Ferrari (P2)", "UFUSC (Joint)"]
+         key_results = [r for r in results if r["method"] in key_methods]
+
+         if len(key_results) < 2:
+             continue
+
+         categories = ["Retain Acc", "1 - Forget Acc", "1 - MIA ASR", "Low Sensitivity"]
+         N = len(categories)
+         angles = [n / float(N) * 2 * np.pi for n in range(N)]
+         angles += angles[:1]  # close the polygon
+
+         fig, ax = plt.subplots(figsize=(8, 8), subplot_kw=dict(polar=True))
+         ax.set_title(f"{dataset_name} — Method Radar Comparison", fontsize=14, fontweight='bold', pad=20)
+
+         for r in key_results:
+             # Map every axis to [0, 1] with 1 = best, so larger polygons are better
+             values = [
+                 r["retain_acc"] / 100,
+                 (100 - r["forget_acc"]) / 100,
+                 (100 - r["mia_asr"]) / 100,
+                 max(0, 1 - r["feature_sensitivity"]),  # clamp in case sensitivity exceeds 1
+             ]
+             values += values[:1]
+             color = colors.get(r["method"], "#333333")
+             ax.plot(angles, values, 'o-', linewidth=2, label=r["method"], color=color)
+             ax.fill(angles, values, alpha=0.1, color=color)
+
+         ax.set_xticks(angles[:-1])
+         ax.set_xticklabels(categories)
+         ax.set_ylim(0, 1)
+         ax.legend(loc='upper right', bbox_to_anchor=(1.3, 1.1))
+
+         plt.tight_layout()
+         plt.savefig(f"figures/{dataset_name.replace('-', '_')}_radar.png", dpi=150, bbox_inches='tight')
+         plt.close()
+         print(f"  Saved: figures/{dataset_name.replace('-', '_')}_radar.png")
+
+     # ---- Ablation Study Plot ----
+     if ablation_results:
+         fig, axes = plt.subplots(2, 2, figsize=(14, 10))
+         fig.suptitle("UFUSC Ablation Study (MNIST)", fontsize=16, fontweight='bold')
+
+         params = {"alpha": "α (CFL weight)", "beta": "β (Sensitivity weight)",
+                   "gamma": "γ (Anchor weight)", "epochs": "Unlearning Epochs"}
+
+         for idx, (param_key, param_label) in enumerate(params.items()):
+             ax = axes[idx // 2][idx % 2]
+             param_results = [r for r in ablation_results if r["ablation_param"] == param_key]
+
+             if not param_results:
+                 continue
+
+             x_vals = [r["ablation_value"] for r in param_results]
+             forget_vals = [r["forget_acc"] for r in param_results]
+             retain_vals = [r["retain_acc"] for r in param_results]
+
+             ax.plot(x_vals, forget_vals, 's-', color='#e74c3c', label='Forget Acc ↓', linewidth=2, markersize=8)
+             ax.plot(x_vals, retain_vals, 'o-', color='#2980b9', label='Retain Acc ↑', linewidth=2, markersize=8)
+             ax.set_xlabel(param_label)
+             ax.set_ylabel("Accuracy (%)")
+             ax.set_title(f"Effect of {param_label}")
+             ax.legend()
+             ax.grid(True, alpha=0.3)
+
+         plt.tight_layout()
+         plt.savefig("figures/ablation_study.png", dpi=150, bbox_inches='tight')
+         plt.close()
+         print("  Saved: figures/ablation_study.png")
+
+     # ---- Scalability Analysis Plot ----
+     if scalability_results:
+         fig, axes = plt.subplots(1, 2, figsize=(14, 5))
+         fig.suptitle("UFUSC Scalability Analysis (Varying K)", fontsize=14, fontweight='bold')
+
+         ks = [r["K"] for r in scalability_results]
+
+         # Accuracy and privacy metrics vs K
+         axes[0].plot(ks, [r["ufusc_forget_acc"] for r in scalability_results],
+                      's-', color='#e74c3c', label='Forget Acc ↓', linewidth=2, markersize=8)
+         axes[0].plot(ks, [r["ufusc_retain_acc"] for r in scalability_results],
+                      'o-', color='#2980b9', label='Retain Acc ↑', linewidth=2, markersize=8)
+         axes[0].plot(ks, [r["ufusc_mia_asr"] for r in scalability_results],
+                      '^-', color='#27ae60', label='MIA ASR ↓', linewidth=2, markersize=8)
+         axes[0].set_xlabel("Number of Passive Parties (K)")
+         axes[0].set_ylabel("Metric (%)")
+         axes[0].set_title("Metrics vs K")
+         axes[0].legend()
+         axes[0].set_xticks(ks)
+
+         # Wall-clock unlearning time vs K
+         axes[1].bar(ks, [r["time_seconds"] for r in scalability_results],
+                     color='#8e44ad', alpha=0.7)
+         axes[1].set_xlabel("Number of Passive Parties (K)")
+         axes[1].set_ylabel("Time (seconds)")
+         axes[1].set_title("Unlearning Time vs K")
+         axes[1].set_xticks(ks)
+
+         plt.tight_layout()
+         plt.savefig("figures/scalability_analysis.png", dpi=150, bbox_inches='tight')
+         plt.close()
+         print("  Saved: figures/scalability_analysis.png")
+
+     # ---- Privacy-Utility Tradeoff Plots ----
+     for dataset_name, results in all_results.items():
+         fig, ax = plt.subplots(figsize=(10, 7))
+         ax.set_title(f"{dataset_name} — Privacy-Utility Tradeoff", fontsize=14, fontweight='bold')
+
+         for r in results:
+             if r["method"] == "Original (No Unlearn)":
+                 continue
+             color = colors.get(r["method"], "#333333")
+             marker = 'D' if 'UFUSC' in r["method"] else 'o'
+             size = 200 if 'UFUSC' in r["method"] else 100
+             ax.scatter(r["retain_acc"], 100 - r["mia_asr"],
+                        c=color, s=size, marker=marker,
+                        label=r["method"], edgecolors='black', linewidth=0.5, zorder=5)
+
+         ax.set_xlabel("Retain Accuracy (%) ↑ — Utility", fontsize=12)
+         ax.set_ylabel("Privacy Protection (100 - MIA ASR) ↑", fontsize=12)
+         ax.legend(fontsize=9, loc='best')
+         ax.grid(True, alpha=0.3)
+
+         # Annotate the ideal region: upper-right dominates on both axes
+         ax.annotate("Upper-right = better privacy & utility",
+                     xy=(0.5, 0.02), xycoords='axes fraction',
+                     fontsize=10, ha='center', alpha=0.5, style='italic')
+
+         plt.tight_layout()
+         plt.savefig(f"figures/{dataset_name.replace('-', '_')}_tradeoff.png", dpi=150, bbox_inches='tight')
+         plt.close()
+         print(f"  Saved: figures/{dataset_name.replace('-', '_')}_tradeoff.png")
+
+
+ # ============================================================================
+ # Main Execution
+ # ============================================================================
+
+ def main():
+     """
+     Full experimental pipeline:
+       1. Run the main experiments on MNIST, Fashion-MNIST, and CIFAR-10
+       2. Run the ablation study on MNIST
+       3. Run the scalability analysis on MNIST
+       4. Generate all visualizations
+     Results are saved to JSON as each stage finishes.
+     """
+     print("=" * 70)
+     print("  UFUSC: Unified Federated Unlearning via")
+     print("         Sensitivity-Guided Contrastive Forgetting")
+     print("=" * 70)
+     print(f"  Device: {DEVICE}")
+     print(f"  Seed: {SEED}")
+     print(f"  VFL Parties: {NUM_PASSIVE_PARTIES}")
+     print(f"  Batch Size: {BATCH_SIZE}")
+     print(f"  Train Epochs: {TRAIN_EPOCHS}")
+     print(f"  Unlearn Epochs: {UNLEARN_EPOCHS}")
+     print(f"  Forget Ratio: {FORGET_RATIO}")
+     print(f"  UFUSC params: α={ALPHA}, β={BETA}, γ={GAMMA}, Ω={OMEGA}, τ={TAU}")
+     print()
+
+     # Ensure the results directory exists before anything is written
+     # (exist_ok makes this a no-op if it was created earlier in the file)
+     import os
+     os.makedirs("results", exist_ok=True)
+
+     # ---- Main Experiments ----
+     all_results = {}
+     for dataset_name in ["MNIST", "Fashion-MNIST", "CIFAR-10"]:
+         results = run_single_experiment(dataset_name)
+         all_results[dataset_name] = results
+
+     # Save main results
+     with open("results/all_results.json", "w") as f:
+         json.dump(all_results, f, indent=2)
+     print("\n✓ Saved: results/all_results.json")
+
+     # ---- Ablation Study ----
+     ablation_results = run_ablation_study("MNIST")
+     with open("results/ablation_results.json", "w") as f:
+         json.dump(ablation_results, f, indent=2)
+     print("✓ Saved: results/ablation_results.json")
+
+     # ---- Scalability Analysis ----
+     scalability_results = run_scalability_analysis("MNIST")
+     with open("results/scalability_results.json", "w") as f:
+         json.dump(scalability_results, f, indent=2)
+     print("✓ Saved: results/scalability_results.json")
+
+     # ---- Visualizations ----
+     print("\n" + "=" * 70)
+     print("  GENERATING VISUALIZATIONS")
+     print("=" * 70)
+     create_visualizations(all_results, ablation_results, scalability_results)
+
+     # ---- Final Summary ----
+     print("\n" + "=" * 70)
+     print("  FINAL SUMMARY")
+     print("=" * 70)
+
+     for dataset_name, results in all_results.items():
+         joint = next((r for r in results if r["method"] == "UFUSC (Joint)"), None)
+         if joint:
+             print(f"\n  {dataset_name}:")
+             print(f"    UFUSC-Joint → Retain: {joint['retain_acc']:.1f}%, "
+                   f"Forget: {joint['forget_acc']:.1f}%, MIA: {joint['mia_asr']:.1f}%")
+
+     print("\n  All experiments complete!")
+     print("  Results:     results/all_results.json")
+     print("  Ablation:    results/ablation_results.json")
+     print("  Scalability: results/scalability_results.json")
+     print("  Figures:     figures/*.png")
+     print("=" * 70)
+
+
+ if __name__ == "__main__":
+     main()