Ellaft commited on
Commit
63eaaee
·
verified ·
1 Parent(s): 38fdf87

Add training script and ablation runner

Browse files
Files changed (1) hide show
  1. src/train.py +214 -0
src/train.py ADDED
@@ -0,0 +1,214 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Multimodal PC Fault Detection - Training Script
3
+ =================================================
4
+ Usage:
5
+ python train.py --mode multimodal --finetune lora --eval_robustness
6
+ python train.py --mode visual_only --finetune lora --no_push
7
+ python train.py --mode audio_only --finetune lora --no_push
8
+ python train.py --mode multimodal --finetune full --lr 2e-5
9
+ python train.py --quick_test --no_push
10
+ """
11
+
12
+ import os, sys, json, argparse, time
13
+ import numpy as np
14
+ import torch
15
+ import torch.nn as nn
16
+ from torch.utils.data import DataLoader
17
+ from torch.optim import AdamW
18
+ from torch.optim.lr_scheduler import OneCycleLR
19
+ from sklearn.metrics import accuracy_score, f1_score, confusion_matrix, precision_recall_fscore_support
20
+ from config import ExperimentConfig, FAULT_CLASSES, NUM_CLASSES
21
+ from dataset import PCFaultDataset, multimodal_collate_fn
22
+ from models import create_model, get_processors
23
+
24
+
25
+ def compute_metrics(preds, labels, class_names=FAULT_CLASSES):
26
+ accuracy = accuracy_score(labels, preds)
27
+ precision, recall, f1, support = precision_recall_fscore_support(
28
+ labels, preds, average=None, labels=range(len(class_names)), zero_division=0)
29
+ macro_f1 = f1_score(labels, preds, average="macro", zero_division=0)
30
+ weighted_f1 = f1_score(labels, preds, average="weighted", zero_division=0)
31
+ conf_matrix = confusion_matrix(labels, preds, labels=range(len(class_names)))
32
+ metrics = {"accuracy": accuracy, "macro_f1": macro_f1, "weighted_f1": weighted_f1,
33
+ "confusion_matrix": conf_matrix.tolist(), "per_class": {}}
34
+ for i, name in enumerate(class_names):
35
+ metrics["per_class"][name] = {"precision": precision[i], "recall": recall[i], "f1": f1[i], "support": int(support[i])}
36
+ return metrics
37
+
38
+
39
+ class MultimodalTrainer:
40
+ def __init__(self, model, train_dataset, val_dataset, config, device):
41
+ self.model = model.to(device)
42
+ self.device, self.config = device, config
43
+ self.train_loader = DataLoader(train_dataset, batch_size=config.per_device_train_batch_size,
44
+ shuffle=True, collate_fn=multimodal_collate_fn, num_workers=2, pin_memory=True, drop_last=True)
45
+ self.val_loader = DataLoader(val_dataset, batch_size=config.per_device_eval_batch_size,
46
+ shuffle=False, collate_fn=multimodal_collate_fn, num_workers=2, pin_memory=True)
47
+ param_groups = self._get_param_groups()
48
+ self.optimizer = AdamW(param_groups, weight_decay=config.weight_decay)
49
+ total_steps = len(self.train_loader) * config.num_epochs // config.gradient_accumulation_steps
50
+ self.scheduler = OneCycleLR(self.optimizer, max_lr=[pg["lr"] for pg in param_groups],
51
+ total_steps=total_steps, pct_start=config.warmup_ratio, anneal_strategy="cos")
52
+ self.scaler = torch.amp.GradScaler("cuda") if config.fp16 and device.type == "cuda" else None
53
+ self.best_metric, self.best_epoch = 0.0, 0
54
+ self.history = {"train_loss": [], "val_loss": [], "val_accuracy": [], "val_macro_f1": []}
55
+
56
+ def _get_param_groups(self):
57
+ lora_params, other_params = [], []
58
+ for name, param in self.model.named_parameters():
59
+ if not param.requires_grad: continue
60
+ (lora_params if "lora" in name.lower() else other_params).append(param)
61
+ groups = []
62
+ if lora_params: groups.append({"params": lora_params, "lr": self.config.lora_learning_rate})
63
+ if other_params: groups.append({"params": other_params, "lr": self.config.learning_rate})
64
+ if not groups: raise ValueError("No trainable parameters!")
65
+ return groups
66
+
67
+ def train_epoch(self, epoch):
68
+ self.model.train()
69
+ total_loss, num_batches = 0.0, 0
70
+ self.optimizer.zero_grad()
71
+ for batch_idx, batch in enumerate(self.train_loader):
72
+ pv = batch["pixel_values"].to(self.device)
73
+ av = batch["audio_values"].to(self.device)
74
+ labels = batch["labels"].to(self.device)
75
+ if self.scaler:
76
+ with torch.amp.autocast("cuda"):
77
+ outputs = self.model(pixel_values=pv, audio_values=av, labels=labels)
78
+ loss = outputs["loss"] / self.config.gradient_accumulation_steps
79
+ self.scaler.scale(loss).backward()
80
+ else:
81
+ outputs = self.model(pixel_values=pv, audio_values=av, labels=labels)
82
+ loss = outputs["loss"] / self.config.gradient_accumulation_steps
83
+ loss.backward()
84
+ total_loss += loss.item() * self.config.gradient_accumulation_steps
85
+ num_batches += 1
86
+ if (batch_idx + 1) % self.config.gradient_accumulation_steps == 0:
87
+ if self.scaler:
88
+ self.scaler.unscale_(self.optimizer)
89
+ torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.config.max_grad_norm)
90
+ self.scaler.step(self.optimizer)
91
+ self.scaler.update()
92
+ else:
93
+ torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.config.max_grad_norm)
94
+ self.optimizer.step()
95
+ self.scheduler.step()
96
+ self.optimizer.zero_grad()
97
+ if (batch_idx + 1) % self.config.logging_steps == 0 or batch_idx == 0:
98
+ print(f" [Epoch {epoch+1}] Step {batch_idx+1}/{len(self.train_loader)} | Loss: {total_loss/num_batches:.4f} | LR: {self.optimizer.param_groups[0]['lr']:.2e}")
99
+ return total_loss / max(num_batches, 1)
100
+
101
+ @torch.no_grad()
102
+ def evaluate(self, modality_mask=None):
103
+ self.model.eval()
104
+ all_preds, all_labels, total_loss, num_batches = [], [], 0.0, 0
105
+ for batch in self.val_loader:
106
+ pv = batch["pixel_values"].to(self.device)
107
+ av = batch["audio_values"].to(self.device)
108
+ labels = batch["labels"].to(self.device)
109
+ if modality_mask:
110
+ if modality_mask.get("visual", 1.0) == 0.0: pv = torch.zeros_like(pv)
111
+ if modality_mask.get("audio", 1.0) == 0.0: av = torch.zeros_like(av)
112
+ outputs = self.model(pixel_values=pv, audio_values=av, labels=labels)
113
+ total_loss += outputs["loss"].item()
114
+ num_batches += 1
115
+ all_preds.extend(outputs["logits"].argmax(dim=-1).cpu().numpy())
116
+ all_labels.extend(labels.cpu().numpy())
117
+ metrics = compute_metrics(np.array(all_preds), np.array(all_labels))
118
+ metrics["val_loss"] = total_loss / max(num_batches, 1)
119
+ return metrics
120
+
121
+ def train(self):
122
+ print(f"\\nTraining: mode={self.model.mode}, epochs={self.config.num_epochs}, batch={self.config.per_device_train_batch_size}, device={self.device}")
123
+ for epoch in range(self.config.num_epochs):
124
+ t0 = time.time()
125
+ train_loss = self.train_epoch(epoch)
126
+ val_metrics = self.evaluate()
127
+ print(f"\\n[Epoch {epoch+1}/{self.config.num_epochs}] ({time.time()-t0:.1f}s) Train Loss: {train_loss:.4f} | Val Loss: {val_metrics['val_loss']:.4f} | Acc: {val_metrics['accuracy']:.4f} | F1: {val_metrics['macro_f1']:.4f}")
128
+ self.history["train_loss"].append(train_loss)
129
+ self.history["val_loss"].append(val_metrics["val_loss"])
130
+ self.history["val_accuracy"].append(val_metrics["accuracy"])
131
+ self.history["val_macro_f1"].append(val_metrics["macro_f1"])
132
+ if val_metrics[self.config.metric_for_best_model] > self.best_metric:
133
+ self.best_metric = val_metrics[self.config.metric_for_best_model]
134
+ self.best_epoch = epoch + 1
135
+ os.makedirs(self.config.output_dir, exist_ok=True)
136
+ torch.save({"model_state_dict": self.model.state_dict(), "epoch": epoch + 1, "metrics": val_metrics},
137
+ os.path.join(self.config.output_dir, "best_model.pt"))
138
+ print(f" ✓ Best model saved (F1={self.best_metric:.4f})")
139
+ print(f"\\nTraining complete. Best epoch={self.best_epoch}, Best F1={self.best_metric:.4f}")
140
+ return self.history
141
+
142
+ def run_robustness_evaluation(self):
143
+ print("\\n=== Missing Modality Robustness ===")
144
+ results = {}
145
+ for name, mask in [("both", None), ("visual_only", {"visual": 1.0, "audio": 0.0}), ("audio_only", {"visual": 0.0, "audio": 1.0})]:
146
+ m = self.evaluate(modality_mask=mask)
147
+ results[name] = {"accuracy": m["accuracy"], "macro_f1": m["macro_f1"]}
148
+ print(f" {name}: Acc={m['accuracy']:.4f} F1={m['macro_f1']:.4f}")
149
+ return results
150
+
151
+
152
+ def main():
153
+ parser = argparse.ArgumentParser()
154
+ parser.add_argument("--mode", default="multimodal", choices=["multimodal", "visual_only", "audio_only"])
155
+ parser.add_argument("--finetune", default="lora", choices=["lora", "full", "linear_probe"])
156
+ parser.add_argument("--epochs", type=int); parser.add_argument("--batch_size", type=int)
157
+ parser.add_argument("--lr", type=float); parser.add_argument("--fusion", default="concat")
158
+ parser.add_argument("--modality_dropout", type=float); parser.add_argument("--output_dir", type=str)
159
+ parser.add_argument("--hub_model_id", type=str); parser.add_argument("--no_push", action="store_true")
160
+ parser.add_argument("--eval_robustness", action="store_true"); parser.add_argument("--quick_test", action="store_true")
161
+ args = parser.parse_args()
162
+
163
+ config = ExperimentConfig()
164
+ config.train.mode, config.train.finetune_method, config.model.fusion_type = args.mode, args.finetune, args.fusion
165
+ if args.epochs: config.train.num_epochs = args.epochs
166
+ if args.batch_size: config.train.per_device_train_batch_size = args.batch_size
167
+ if args.lr: config.train.learning_rate = config.train.lora_learning_rate = args.lr
168
+ if args.modality_dropout is not None: config.model.modality_dropout_p = args.modality_dropout
169
+ if args.output_dir: config.train.output_dir = args.output_dir
170
+ if args.hub_model_id: config.train.hub_model_id = args.hub_model_id
171
+ if args.no_push: config.train.push_to_hub = False
172
+ if args.quick_test:
173
+ config.train.num_epochs, config.train.per_device_train_batch_size = 2, 4
174
+ config.train.per_device_eval_batch_size, config.train.gradient_accumulation_steps = 4, 1
175
+ config.train.logging_steps = 2
176
+ if args.finetune != "lora": config.lora.enabled = False
177
+
178
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
179
+ torch.manual_seed(config.train.seed); np.random.seed(config.train.seed)
180
+
181
+ vit_proc, ast_ext = get_processors(config.model)
182
+ train_ds = PCFaultDataset(config.data, config.model, "train", vit_proc, ast_ext, True)
183
+ val_ds = PCFaultDataset(config.data, config.model, "val", vit_proc, ast_ext, False)
184
+ model = create_model(config.model, config.lora, config.train.mode, config.train.finetune_method)
185
+ trainer = MultimodalTrainer(model, train_ds, val_ds, config.train, device)
186
+ history = trainer.train()
187
+
188
+ final = trainer.evaluate()
189
+ print(f"\\nFinal: Acc={final['accuracy']:.4f} F1={final['macro_f1']:.4f}")
190
+ for cls, m in final["per_class"].items():
191
+ print(f" {cls:25s} P:{m['precision']:.3f} R:{m['recall']:.3f} F1:{m['f1']:.3f} N:{m['support']}")
192
+
193
+ if args.eval_robustness and config.train.mode == "multimodal":
194
+ trainer.run_robustness_evaluation()
195
+
196
+ os.makedirs(config.train.output_dir, exist_ok=True)
197
+ with open(os.path.join(config.train.output_dir, "results.json"), "w") as f:
198
+ json.dump({"experiment": config.experiment_name, "mode": config.train.mode, "finetune_method": config.train.finetune_method,
199
+ "final_metrics": {"accuracy": final["accuracy"], "macro_f1": final["macro_f1"], "weighted_f1": final["weighted_f1"],
200
+ "per_class": final["per_class"], "confusion_matrix": final["confusion_matrix"]},
201
+ "history": history, "best_epoch": trainer.best_epoch, "best_metric": trainer.best_metric}, f, indent=2)
202
+
203
+ if config.train.push_to_hub:
204
+ try:
205
+ from huggingface_hub import HfApi, login
206
+ login(token=os.environ.get("HF_TOKEN"))
207
+ HfApi().upload_folder(folder_path=config.train.output_dir, repo_id=config.train.hub_model_id,
208
+ repo_type="model", commit_message=f"Training: {config.experiment_name}")
209
+ print(f"✓ Pushed to https://huggingface.co/{config.train.hub_model_id}")
210
+ except Exception as e:
211
+ print(f"✗ Push failed: {e}")
212
+
213
+ if __name__ == "__main__":
214
+ main()