Ellaft commited on
Commit
34b5fb4
·
verified ·
1 Parent(s): 0c3f8a7

Add run_ablations_v2.py: v2 ablation runner with OGM-GE experiments

Browse files
Files changed (1) hide show
  1. src/run_ablations_v2.py +138 -0
src/run_ablations_v2.py ADDED
@@ -0,0 +1,138 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Multimodal PC Fault Detection - Ablation Runner v2
3
+ ====================================================
4
+ Runs all ablation experiments using train_v2.py with OGM-GE.
5
+
6
+ Usage: python run_ablations_v2.py [--quick_test]
7
+ """
8
+
9
+ import os, sys, json, subprocess
10
+ from config import FAULT_CLASSES
11
+
12
+
13
+ def run_ablation_suite(quick_test=False):
14
+ experiments = [
15
+ {
16
+ "name": "Multimodal + LoRA + OGM-GE",
17
+ "args": [
18
+ "--mode", "multimodal", "--finetune", "lora",
19
+ "--fusion", "concat", "--eval_robustness",
20
+ "--output_dir", "./results/multimodal_lora_ogmge",
21
+ "--hub_model_id", "Ellaft/pc-fault-multimodal-lora-ogmge",
22
+ ],
23
+ },
24
+ {
25
+ "name": "Visual Only + LoRA",
26
+ "args": [
27
+ "--mode", "visual_only", "--finetune", "lora",
28
+ "--output_dir", "./results/visual_only_lora",
29
+ "--hub_model_id", "Ellaft/pc-fault-visual-only",
30
+ ],
31
+ },
32
+ {
33
+ "name": "Audio Only + LoRA",
34
+ "args": [
35
+ "--mode", "audio_only", "--finetune", "lora",
36
+ "--output_dir", "./results/audio_only_lora",
37
+ "--hub_model_id", "Ellaft/pc-fault-audio-only",
38
+ ],
39
+ },
40
+ {
41
+ "name": "Multimodal + Full FT + OGM-GE",
42
+ "args": [
43
+ "--mode", "multimodal", "--finetune", "full",
44
+ "--lr", "2e-5", "--fusion", "concat",
45
+ "--output_dir", "./results/multimodal_full_ft_ogmge",
46
+ "--hub_model_id", "Ellaft/pc-fault-multimodal-full-ft",
47
+ ],
48
+ },
49
+ {
50
+ "name": "Multimodal + Linear Probe + OGM-GE",
51
+ "args": [
52
+ "--mode", "multimodal", "--finetune", "linear_probe",
53
+ "--lr", "1e-3", "--fusion", "concat",
54
+ "--output_dir", "./results/multimodal_linear_probe_ogmge",
55
+ "--hub_model_id", "Ellaft/pc-fault-multimodal-linear-probe",
56
+ ],
57
+ },
58
+ {
59
+ "name": "Multimodal + LoRA + High Dropout + OGM-GE",
60
+ "args": [
61
+ "--mode", "multimodal", "--finetune", "lora",
62
+ "--modality_dropout", "0.5", "--fusion", "concat",
63
+ "--eval_robustness",
64
+ "--output_dir", "./results/multimodal_robust_ogmge",
65
+ "--hub_model_id", "Ellaft/pc-fault-multimodal-robust",
66
+ ],
67
+ },
68
+ ]
69
+
70
+ if quick_test:
71
+ for exp in experiments:
72
+ exp["args"].append("--quick_test")
73
+
74
+ all_results = {}
75
+ for exp in experiments:
76
+ print(f"\n{'='*60}")
77
+ print(f"Running: {exp['name']}")
78
+ print(f"{'='*60}")
79
+ result = subprocess.run(
80
+ [sys.executable, "train_v2.py"] + exp["args"],
81
+ capture_output=False)
82
+
83
+ results_dir = exp["args"][exp["args"].index("--output_dir") + 1]
84
+ rpath = os.path.join(results_dir, "results_v2.json")
85
+ if os.path.exists(rpath):
86
+ all_results[exp["name"]] = json.load(open(rpath))
87
+ else:
88
+ all_results[exp["name"]] = {"error": f"rc={result.returncode}"}
89
+
90
+ # Summary table
91
+ print(f"\n\n{'='*80}")
92
+ print(f"ABLATION STUDY RESULTS (v2 — with OGM-GE anti-collapse)")
93
+ print(f"{'='*80}")
94
+ print(f"{'Experiment':<42s} {'Accuracy':>10s} {'Macro-F1':>10s} {'Weighted-F1':>12s}")
95
+ print("-" * 76)
96
+ for name, r in all_results.items():
97
+ if "error" in r:
98
+ print(f"{name:<42s} {'ERROR':>10s}")
99
+ else:
100
+ m = r["final_metrics"]
101
+ print(f"{name:<42s} {m['accuracy']:>10.4f} {m['macro_f1']:>10.4f} {m['weighted_f1']:>12.4f}")
102
+
103
+ # Robustness comparison (if available)
104
+ print(f"\n{'='*80}")
105
+ print(f"MISSING MODALITY ROBUSTNESS")
106
+ print(f"{'='*80}")
107
+ print(f"{'Experiment':<42s} {'Both':>10s} {'Vis-Only':>10s} {'Aud-Only':>10s}")
108
+ print("-" * 76)
109
+ for name, r in all_results.items():
110
+ if "robustness" in r:
111
+ rb = r["robustness"]
112
+ print(f"{name:<42s} "
113
+ f"{rb.get('both_modalities', {}).get('macro_f1', 0):>10.4f} "
114
+ f"{rb.get('visual_only', {}).get('macro_f1', 0):>10.4f} "
115
+ f"{rb.get('audio_only', {}).get('macro_f1', 0):>10.4f}")
116
+
117
+ # v1 vs v2 comparison reminder
118
+ print(f"\n{'='*80}")
119
+ print("v1 BASELINE (for comparison):")
120
+ print(" Both modalities: Acc=0.9423 F1=0.9481")
121
+ print(" Audio only: Acc=0.9231 F1=0.9316")
122
+ print(" Visual only: Acc=0.2308 F1=0.0750 ← MODALITY COLLAPSE")
123
+ print(f"{'='*80}")
124
+
125
+ # Save summary
126
+ os.makedirs("./results", exist_ok=True)
127
+ with open("./results/ablation_summary_v2.json", "w") as f:
128
+ json.dump(all_results, f, indent=2)
129
+ print(f"\nFull results saved to ./results/ablation_summary_v2.json")
130
+
131
+ return all_results
132
+
133
+
134
+ if __name__ == "__main__":
135
+ import argparse
136
+ parser = argparse.ArgumentParser()
137
+ parser.add_argument("--quick_test", action="store_true")
138
+ run_ablation_suite(quick_test=parser.parse_args().quick_test)