File size: 3,409 Bytes
e8c88c3
 
 
1067497
e8c88c3
1067497
 
 
e8c88c3
 
 
 
 
 
 
1067497
 
 
 
 
 
 
 
 
 
 
 
e8c88c3
 
1067497
 
e8c88c3
 
1067497
 
 
 
 
 
 
 
 
e8c88c3
 
1067497
 
 
 
 
 
 
 
 
 
e8c88c3
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
"""
Multimodal PC Fault Detection - Ablation Runner
=================================================
Runs all 6 ablation experiments.

Usage:
  python run_ablations.py                # Full run
  python run_ablations.py --quick_test   # Quick 2-epoch test
"""

import os, sys, json, subprocess


def _build_experiments(quick_test):
    """Return the ablation configs: a display name plus the train.py CLI args for each."""
    experiments = [
        {"name": "Multimodal + LoRA",
         "args": ["--mode", "multimodal", "--finetune", "lora", "--eval_robustness",
                  "--output_dir", "./results/multimodal_lora",
                  "--hub_model_id", "Ellaft/pc-fault-multimodal-lora"]},
        {"name": "Visual Only + LoRA",
         "args": ["--mode", "visual_only", "--finetune", "lora",
                  "--output_dir", "./results/visual_only_lora",
                  "--hub_model_id", "Ellaft/pc-fault-visual-only"]},
        {"name": "Audio Only + LoRA",
         "args": ["--mode", "audio_only", "--finetune", "lora",
                  "--output_dir", "./results/audio_only_lora",
                  "--hub_model_id", "Ellaft/pc-fault-audio-only"]},
        {"name": "Multimodal + Full FT",
         "args": ["--mode", "multimodal", "--finetune", "full", "--lr", "2e-5",
                  "--output_dir", "./results/multimodal_full_ft",
                  "--hub_model_id", "Ellaft/pc-fault-multimodal-full-ft"]},
        {"name": "Multimodal + Linear Probe",
         "args": ["--mode", "multimodal", "--finetune", "linear_probe", "--lr", "1e-3",
                  "--output_dir", "./results/multimodal_linear_probe",
                  "--hub_model_id", "Ellaft/pc-fault-multimodal-linear-probe"]},
        {"name": "Multimodal + High Dropout",
         "args": ["--mode", "multimodal", "--finetune", "lora", "--modality_dropout", "0.5",
                  "--eval_robustness", "--output_dir", "./results/multimodal_robust",
                  "--hub_model_id", "Ellaft/pc-fault-multimodal-robust"]},
    ]
    if quick_test:
        for exp in experiments:
            exp["args"].extend(["--quick_test", "--no_push"])
    return experiments


def _run_experiment(exp):
    """Run one train.py subprocess; return its parsed results.json or an error dict."""
    # flush=True: when stdout is piped/redirected it is block-buffered, and the
    # banner would otherwise be emitted after the child's own output.
    print(f"\n{'='*60}\nRunning: {exp['name']}\n{'='*60}", flush=True)
    args = exp["args"]
    proc = subprocess.run([sys.executable, "train.py"] + args)

    rdir = args[args.index("--output_dir") + 1]
    rpath = os.path.join(rdir, "results.json")
    if proc.returncode != 0:
        # Do not read results.json here: a file left over from an earlier run
        # would otherwise be reported as this (failed) run's result.
        return {"error": "failed", "returncode": proc.returncode}
    try:
        with open(rpath, encoding="utf-8") as f:
            data = json.load(f)
    except FileNotFoundError:
        return {"error": "failed", "reason": f"missing {rpath}"}
    except (OSError, ValueError) as e:  # ValueError covers json.JSONDecodeError
        return {"error": "failed", "reason": f"could not read {rpath}: {e}"}
    if not isinstance(data, dict):
        return {"error": "failed", "reason": f"unexpected content in {rpath}"}
    return data


def _print_summary(all_results):
    """Print the accuracy/macro-F1 table and the missing-modality robustness table."""
    print(f"\n\n{'='*80}\nABLATION RESULTS\n{'='*80}")
    print(f"{'Experiment':<35s} {'Accuracy':>10s} {'Macro-F1':>10s}")
    print("-" * 58)
    for name, r in all_results.items():
        m = None if "error" in r else r.get("final_metrics")
        try:
            row = f"{m['accuracy']:>10.4f} {m['macro_f1']:>10.4f}"
        except (TypeError, KeyError, ValueError):
            # Failed run, or a results.json lacking the expected metrics.
            row = f"{'ERROR':>10s} {'-':>10s}"
        print(f"{name:<35s} {row}")

    print(f"\n{'='*80}\nROBUSTNESS (missing modality F1)\n{'='*80}")
    for name, r in all_results.items():
        if "robustness" in r:
            rb = r["robustness"]
            print(f"{name:<35s} both={rb.get('both',{}).get('macro_f1',0):.4f} "
                  f"vis={rb.get('visual_only',{}).get('macro_f1',0):.4f} "
                  f"aud={rb.get('audio_only',{}).get('macro_f1',0):.4f}")


def run_ablation_suite(quick_test=False):
    """Run every ablation experiment through ``train.py`` and summarize the results.

    Each experiment is launched as a separate ``train.py`` subprocess using the
    current interpreter. ``train.py`` is resolved relative to the current
    working directory. After each run, ``<output_dir>/results.json`` is loaded
    and stored under the experiment's display name.

    A run is recorded as an error dict (always containing ``"error": "failed"``)
    in any of these cases: the subprocess exits non-zero, the results file is
    missing, or the file cannot be read or parsed as a JSON object. A single bad
    experiment therefore does not abort the rest of the suite.

    The combined results are printed as two tables: accuracy / macro-F1, and
    missing-modality robustness where available. They are also written to
    ``./results/ablation_summary.json``.

    Args:
        quick_test: If True, append ``--quick_test`` and ``--no_push`` to every
            ``train.py`` invocation.

    Returns:
        dict mapping each experiment name to its parsed results, or to an error
        dict for failed runs.
    """
    all_results = {}
    for exp in _build_experiments(quick_test):
        all_results[exp["name"]] = _run_experiment(exp)

    _print_summary(all_results)

    os.makedirs("./results", exist_ok=True)
    with open("./results/ablation_summary.json", "w", encoding="utf-8") as f:
        json.dump(all_results, f, indent=2)
    return all_results


if __name__ == "__main__":
    # Command-line entry point: the only option is --quick_test, which is
    # forwarded to run_ablation_suite.
    import argparse

    cli = argparse.ArgumentParser()
    cli.add_argument("--quick_test", action="store_true")
    cli_args = cli.parse_args()
    run_ablation_suite(quick_test=cli_args.quick_test)