File size: 5,111 Bytes
34b5fb4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
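"""
Multimodal PC Fault Detection - Ablation Runner v2
====================================================
Runs every ablation experiment as a separate train_v2.py subprocess
(v2 training with OGM-GE anti-collapse), prints accuracy/F1 and
missing-modality robustness tables, and saves the combined results to
./results/ablation_summary_v2.json.

Usage: python run_ablations_v2.py [--quick_test]
"""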

import os, sys, json, subprocess
from config import FAULT_CLASSES


def run_ablation_suite(quick_test=False):
    """Run every ablation experiment through ``train_v2.py`` and summarise them.

    Each experiment is launched as its own ``train_v2.py`` subprocess using
    the current interpreter, with its own ``--output_dir``. After the run,
    ``<output_dir>/results_v2.json`` is loaded only if this run created or
    rewrote it. A missing, stale, or unparseable file is recorded as an
    ``{"error": "rc=<returncode>..."}`` entry instead.

    All paths (``train_v2.py``, ``./results/...``) are relative to the current
    working directory.

    Args:
        quick_test: If True, ``--quick_test`` is appended to every
            experiment's command line.

    Returns:
        dict mapping each experiment name to either the parsed
        ``results_v2.json`` contents or an ``{"error": ...}`` dict. The same
        mapping is written to ``./results/ablation_summary_v2.json``.
    """
    experiments = [
        {
            "name": "Multimodal + LoRA + OGM-GE",
            "args": [
                "--mode", "multimodal", "--finetune", "lora",
                "--fusion", "concat", "--eval_robustness",
                "--output_dir", "./results/multimodal_lora_ogmge",
                "--hub_model_id", "Ellaft/pc-fault-multimodal-lora-ogmge",
            ],
        },
        {
            "name": "Visual Only + LoRA",
            "args": [
                "--mode", "visual_only", "--finetune", "lora",
                "--output_dir", "./results/visual_only_lora",
                "--hub_model_id", "Ellaft/pc-fault-visual-only",
            ],
        },
        {
            "name": "Audio Only + LoRA",
            "args": [
                "--mode", "audio_only", "--finetune", "lora",
                "--output_dir", "./results/audio_only_lora",
                "--hub_model_id", "Ellaft/pc-fault-audio-only",
            ],
        },
        {
            "name": "Multimodal + Full FT + OGM-GE",
            "args": [
                "--mode", "multimodal", "--finetune", "full",
                "--lr", "2e-5", "--fusion", "concat",
                "--output_dir", "./results/multimodal_full_ft_ogmge",
                "--hub_model_id", "Ellaft/pc-fault-multimodal-full-ft",
            ],
        },
        {
            "name": "Multimodal + Linear Probe + OGM-GE",
            "args": [
                "--mode", "multimodal", "--finetune", "linear_probe",
                "--lr", "1e-3", "--fusion", "concat",
                "--output_dir", "./results/multimodal_linear_probe_ogmge",
                "--hub_model_id", "Ellaft/pc-fault-multimodal-linear-probe",
            ],
        },
        {
            "name": "Multimodal + LoRA + High Dropout + OGM-GE",
            "args": [
                "--mode", "multimodal", "--finetune", "lora",
                "--modality_dropout", "0.5", "--fusion", "concat",
                "--eval_robustness",
                "--output_dir", "./results/multimodal_robust_ogmge",
                "--hub_model_id", "Ellaft/pc-fault-multimodal-robust",
            ],
        },
    ]

    if quick_test:
        for exp in experiments:
            exp["args"].append("--quick_test")

    all_results = {}
    for exp in experiments:
        all_results[exp["name"]] = _run_experiment(exp["name"], exp["args"])

    _print_metrics_table(all_results)
    _print_robustness_table(all_results)
    _print_v1_baseline()

    # Save summary
    os.makedirs("./results", exist_ok=True)
    with open("./results/ablation_summary_v2.json", "w") as f:
        json.dump(all_results, f, indent=2)
    print("\nFull results saved to ./results/ablation_summary_v2.json")

    return all_results


def _mtime_ns(path):
    """Return *path*'s modification time in nanoseconds, or None if it is absent."""
    try:
        return os.stat(path).st_mtime_ns
    except OSError:
        return None


def _run_experiment(name, args):
    """Run one ``train_v2.py`` subprocess; return its results dict or an error dict."""
    print(f"\n{'='*60}")
    print(f"Running: {name}")
    print(f"{'='*60}")

    results_dir = args[args.index("--output_dir") + 1]
    rpath = os.path.join(results_dir, "results_v2.json")
    # Snapshot the mtime before launching. A results_v2.json left over from an
    # earlier run must not be reported as output of this (possibly failed) run.
    before = _mtime_ns(rpath)

    result = subprocess.run([sys.executable, "train_v2.py"] + args)

    after = _mtime_ns(rpath)
    if after is None:
        return {"error": f"rc={result.returncode}"}
    if after == before:
        return {"error": f"rc={result.returncode}; stale {rpath} was not rewritten"}
    try:
        with open(rpath) as f:
            return json.load(f)
    except (OSError, ValueError) as err:
        # JSONDecodeError is a ValueError. A truncated or corrupt file (e.g. the
        # trainer died mid-write) should not abort the rest of the suite.
        return {"error": f"rc={result.returncode}; unreadable {rpath}: {err}"}


def _print_metrics_table(all_results):
    """Print accuracy / macro-F1 / weighted-F1 for every experiment."""
    nan = float("nan")
    print(f"\n\n{'='*80}")
    print("ABLATION STUDY RESULTS (v2 — with OGM-GE anti-collapse)")
    print(f"{'='*80}")
    print(f"{'Experiment':<42s} {'Accuracy':>10s} {'Macro-F1':>10s} {'Weighted-F1':>12s}")
    print("-" * 76)
    for name, r in all_results.items():
        ok = isinstance(r, dict) and "error" not in r
        m = r.get("final_metrics") if ok else None
        if not isinstance(m, dict):
            # Failed run, or a results file that lacks final_metrics.
            print(f"{name:<42s} {'ERROR':>10s}")
            continue
        print(f"{name:<42s} {m.get('accuracy', nan):>10.4f} "
              f"{m.get('macro_f1', nan):>10.4f} {m.get('weighted_f1', nan):>12.4f}")


def _print_robustness_table(all_results):
    """Print macro-F1 per modality condition for results carrying 'robustness'."""
    nan = float("nan")
    print(f"\n{'='*80}")
    print("MISSING MODALITY ROBUSTNESS")
    print(f"{'='*80}")
    print(f"{'Experiment':<42s} {'Both':>10s} {'Vis-Only':>10s} {'Aud-Only':>10s}")
    print("-" * 76)
    for name, r in all_results.items():
        rb = r.get("robustness") if isinstance(r, dict) else None
        if not isinstance(rb, dict):
            continue
        # A missing entry prints as nan, not 0.0000. A zero macro-F1 would be
        # indistinguishable from genuine modality collapse.
        cells = " ".join(
            f"{(rb.get(key) or {}).get('macro_f1', nan):>10.4f}"
            for key in ("both_modalities", "visual_only", "audio_only")
        )
        print(f"{name:<42s} {cells}")


def _print_v1_baseline():
    """Print the hard-coded v1 reference numbers for side-by-side comparison."""
    print(f"\n{'='*80}")
    print("v1 BASELINE (for comparison):")
    print("  Both modalities: Acc=0.9423 F1=0.9481")
    print("  Audio only:      Acc=0.9231 F1=0.9316")
    print("  Visual only:     Acc=0.2308 F1=0.0750  ← MODALITY COLLAPSE")
    print(f"{'='*80}")


if __name__ == "__main__":
    # Script entry point: the only CLI option toggles quick-test mode for
    # every experiment in the suite.
    import argparse

    cli = argparse.ArgumentParser()
    cli.add_argument("--quick_test", action="store_true")
    cli_args = cli.parse_args()
    run_ablation_suite(quick_test=cli_args.quick_test)