| """ |
| Multimodal PC Fault Detection - Ablation Runner |
| ================================================= |
| Runs all 6 ablation experiments. |
| |
| Usage: |
| python run_ablations.py # Full run |
| python run_ablations.py --quick_test # Quick 2-epoch test |
| """ |
|
|
| import os, sys, json, subprocess |
|
|
|
|
def _experiment_configs(quick_test):
    """Build the six ablation experiment specs.

    Each spec is a dict with a display ``"name"`` and the ``"args"`` list that
    is forwarded verbatim to train.py. When *quick_test* is true,
    ``--quick_test`` and ``--no_push`` are appended to every experiment's args.
    """
    experiments = [
        {"name": "Multimodal + LoRA", "args": [
            "--mode", "multimodal", "--finetune", "lora", "--eval_robustness",
            "--output_dir", "./results/multimodal_lora",
            "--hub_model_id", "Ellaft/pc-fault-multimodal-lora"]},
        {"name": "Visual Only + LoRA", "args": [
            "--mode", "visual_only", "--finetune", "lora",
            "--output_dir", "./results/visual_only_lora",
            "--hub_model_id", "Ellaft/pc-fault-visual-only"]},
        {"name": "Audio Only + LoRA", "args": [
            "--mode", "audio_only", "--finetune", "lora",
            "--output_dir", "./results/audio_only_lora",
            "--hub_model_id", "Ellaft/pc-fault-audio-only"]},
        {"name": "Multimodal + Full FT", "args": [
            "--mode", "multimodal", "--finetune", "full", "--lr", "2e-5",
            "--output_dir", "./results/multimodal_full_ft",
            "--hub_model_id", "Ellaft/pc-fault-multimodal-full-ft"]},
        {"name": "Multimodal + Linear Probe", "args": [
            "--mode", "multimodal", "--finetune", "linear_probe", "--lr", "1e-3",
            "--output_dir", "./results/multimodal_linear_probe",
            "--hub_model_id", "Ellaft/pc-fault-multimodal-linear-probe"]},
        {"name": "Multimodal + High Dropout", "args": [
            "--mode", "multimodal", "--finetune", "lora", "--modality_dropout", "0.5",
            "--eval_robustness",
            "--output_dir", "./results/multimodal_robust",
            "--hub_model_id", "Ellaft/pc-fault-multimodal-robust"]},
    ]
    if quick_test:
        for exp in experiments:
            exp["args"].extend(["--quick_test", "--no_push"])
    return experiments


def _run_experiment(exp):
    """Run one experiment through train.py and return its parsed results.

    train.py is launched with the current interpreter and resolved relative to
    the current working directory. On success the experiment's
    ``<output_dir>/results.json`` is loaded and returned.

    Returns a dict containing an ``"error"`` key when train.py exits non-zero,
    or when results.json is missing, unreadable, or not a JSON object.
    """
    args = exp["args"]
    proc = subprocess.run([sys.executable, "train.py"] + args, capture_output=False)
    if proc.returncode != 0:
        # Any results.json on disk would be left over from an earlier run;
        # reporting it as this run's outcome would be silently wrong.
        return {"error": "failed", "returncode": proc.returncode}

    rpath = os.path.join(args[args.index("--output_dir") + 1], "results.json")
    try:
        with open(rpath, encoding="utf-8") as f:
            data = json.load(f)
    except (OSError, json.JSONDecodeError) as err:
        # Missing or half-written file: record it instead of aborting the suite.
        return {"error": "failed", "detail": f"{type(err).__name__}: {err}"}
    if not isinstance(data, dict):
        return {"error": "failed", "detail": "results.json is not a JSON object"}
    return data


def _save_summary(all_results):
    """Write all collected results to ./results/ablation_summary.json."""
    os.makedirs("./results", exist_ok=True)
    with open("./results/ablation_summary.json", "w", encoding="utf-8") as f:
        json.dump(all_results, f, indent=2)


def _print_summary(all_results):
    """Print the accuracy/macro-F1 table and the missing-modality robustness table."""
    print(f"\n\n{'='*80}\nABLATION RESULTS\n{'='*80}")
    print(f"{'Experiment':<35s} {'Accuracy':>10s} {'Macro-F1':>10s}")
    print("-" * 58)
    for name, r in all_results.items():
        metrics = r.get("final_metrics")
        if "error" in r or not isinstance(metrics, dict):
            print(f"{name:<35s} {'ERROR':>10s}")
        else:
            acc = metrics.get("accuracy", float("nan"))
            f1 = metrics.get("macro_f1", float("nan"))
            print(f"{name:<35s} {acc:>10.4f} {f1:>10.4f}")

    print(f"\n{'='*80}\nROBUSTNESS (missing modality F1)\n{'='*80}")
    for name, r in all_results.items():
        if "robustness" in r:
            rb = r["robustness"]
            print(f"{name:<35s} both={rb.get('both',{}).get('macro_f1',0):.4f} "
                  f"vis={rb.get('visual_only',{}).get('macro_f1',0):.4f} "
                  f"aud={rb.get('audio_only',{}).get('macro_f1',0):.4f}")


def run_ablation_suite(quick_test=False):
    """Run every ablation experiment sequentially and summarise the results.

    Each experiment is executed as a separate ``python train.py ...`` process
    (train.py is resolved relative to the current working directory). After
    every experiment the cumulative results are written to
    ./results/ablation_summary.json, so finished runs survive a later crash or
    interrupt. Finally, summary tables are printed to stdout.

    Args:
        quick_test: If true, pass ``--quick_test --no_push`` to every run.

    Returns:
        Mapping of experiment name -> parsed results.json dict, or a dict with
        an ``"error"`` key for experiments that failed.
    """
    all_results = {}
    for exp in _experiment_configs(quick_test):
        print(f"\n{'='*60}\nRunning: {exp['name']}\n{'='*60}")
        all_results[exp["name"]] = _run_experiment(exp)
        _save_summary(all_results)
    _print_summary(all_results)
    return all_results
|
|
|
|
if __name__ == "__main__":
    # Command-line entry point: the only option toggles the fast smoke-test mode.
    import argparse

    cli = argparse.ArgumentParser()
    cli.add_argument("--quick_test", action="store_true")
    cli_args = cli.parse_args()
    run_ablation_suite(quick_test=cli_args.quick_test)
|
|