"""
Multimodal PC Fault Detection - Ablation Runner v2
====================================================
Runs all ablation experiments using train_v2.py with OGM-GE.

Usage:
    python run_ablations_v2.py [--quick_test]
"""

import os, sys, json, subprocess
from config import FAULT_CLASSES

# Resolve the training script next to this file so the runner also works when
# launched from another working directory. The --output_dir paths below stay
# relative to the current working directory, which the child process inherits,
# so this script and train_v2.py still agree on where results live.
_TRAIN_SCRIPT = os.path.join(os.path.dirname(os.path.abspath(__file__)), "train_v2.py")

# File name this runner reads from each experiment's --output_dir.
_RESULTS_FILENAME = "results_v2.json"

# Metric keys read from each results file's "final_metrics" section.
_METRIC_KEYS = ("accuracy", "macro_f1", "weighted_f1")


def _build_experiments(quick_test):
    """Return a fresh list of experiment specs (display name + train_v2.py CLI args)."""
    experiments = [
        {
            "name": "Multimodal + LoRA + OGM-GE",
            "args": [
                "--mode", "multimodal", "--finetune", "lora",
                "--fusion", "concat", "--eval_robustness",
                "--output_dir", "./results/multimodal_lora_ogmge",
                "--hub_model_id", "Ellaft/pc-fault-multimodal-lora-ogmge",
            ],
        },
        {
            "name": "Visual Only + LoRA",
            "args": [
                "--mode", "visual_only", "--finetune", "lora",
                "--output_dir", "./results/visual_only_lora",
                "--hub_model_id", "Ellaft/pc-fault-visual-only",
            ],
        },
        {
            "name": "Audio Only + LoRA",
            "args": [
                "--mode", "audio_only", "--finetune", "lora",
                "--output_dir", "./results/audio_only_lora",
                "--hub_model_id", "Ellaft/pc-fault-audio-only",
            ],
        },
        {
            "name": "Multimodal + Full FT + OGM-GE",
            "args": [
                "--mode", "multimodal", "--finetune", "full", "--lr", "2e-5",
                "--fusion", "concat",
                "--output_dir", "./results/multimodal_full_ft_ogmge",
                "--hub_model_id", "Ellaft/pc-fault-multimodal-full-ft",
            ],
        },
        {
            "name": "Multimodal + Linear Probe + OGM-GE",
            "args": [
                "--mode", "multimodal", "--finetune", "linear_probe", "--lr", "1e-3",
                "--fusion", "concat",
                "--output_dir", "./results/multimodal_linear_probe_ogmge",
                "--hub_model_id", "Ellaft/pc-fault-multimodal-linear-probe",
            ],
        },
        {
            "name": "Multimodal + LoRA + High Dropout + OGM-GE",
            "args": [
                "--mode", "multimodal", "--finetune", "lora",
                "--modality_dropout", "0.5", "--fusion", "concat", "--eval_robustness",
                "--output_dir", "./results/multimodal_robust_ogmge",
                "--hub_model_id", "Ellaft/pc-fault-multimodal-robust",
            ],
        },
    ]
    if quick_test:
        for exp in experiments:
            exp["args"].append("--quick_test")
    return experiments


def _output_dir(args):
    """Return the value that follows ``--output_dir`` in a CLI argument list."""
    return args[args.index("--output_dir") + 1]


def _mtime_or_none(path):
    """Return the modification time of *path*, or None if it cannot be stat'ed."""
    try:
        return os.path.getmtime(path)
    except OSError:
        return None


def _run_experiment(exp):
    """Run one experiment through train_v2.py and return its parsed results.

    Returns the dict loaded from ``<output_dir>/results_v2.json`` when this run
    (re)wrote that file. Otherwise returns ``{"error": ...}``: the file is
    missing, was left over from an earlier invocation, is not valid JSON, or
    does not hold a JSON object.
    """
    rpath = os.path.join(_output_dir(exp["args"]), _RESULTS_FILENAME)
    # Snapshot the existing file's mtime. A run that crashes without writing
    # new results must not be credited with numbers from a previous run.
    mtime_before = _mtime_or_none(rpath)

    result = subprocess.run(
        [sys.executable, _TRAIN_SCRIPT] + exp["args"], capture_output=False)

    mtime_after = _mtime_or_none(rpath)
    is_fresh = mtime_after is not None and (
        mtime_before is None or mtime_after > mtime_before)
    if not is_fresh:
        return {"error": f"rc={result.returncode}"}

    try:
        with open(rpath, encoding="utf-8") as f:
            data = json.load(f)
    except (OSError, json.JSONDecodeError) as err:
        # Record the failure and keep going rather than aborting the remaining runs.
        return {"error": f"rc={result.returncode}; unreadable {rpath}: {err}"}
    if not isinstance(data, dict):
        return {"error": f"rc={result.returncode}; unexpected JSON in {rpath}"}

    if result.returncode != 0:
        # Results were written, but the process failed afterwards; surface it.
        print(f"WARNING: {exp['name']} exited with rc={result.returncode} "
              f"after writing {rpath}")
    return data


def _print_summary(all_results):
    """Print the accuracy/F1 table, the robustness table and the v1 baseline."""
    print(f"\n\n{'='*80}")
    print("ABLATION STUDY RESULTS (v2 — with OGM-GE anti-collapse)")
    print(f"{'='*80}")
    print(f"{'Experiment':<42s} {'Accuracy':>10s} {'Macro-F1':>10s} {'Weighted-F1':>12s}")
    print("-" * 76)
    for name, r in all_results.items():
        m = r.get("final_metrics")
        has_metrics = isinstance(m, dict) and all(k in m for k in _METRIC_KEYS)
        if "error" in r or not has_metrics:
            # Incomplete results are shown as ERROR instead of aborting the
            # summary (which previously also skipped saving the JSON file).
            print(f"{name:<42s} {'ERROR':>10s}")
        else:
            print(f"{name:<42s} {m['accuracy']:>10.4f} {m['macro_f1']:>10.4f} "
                  f"{m['weighted_f1']:>12.4f}")

    # Robustness comparison (only for runs whose results contain "robustness").
    print(f"\n{'='*80}")
    print("MISSING MODALITY ROBUSTNESS")
    print(f"{'='*80}")
    print(f"{'Experiment':<42s} {'Both':>10s} {'Vis-Only':>10s} {'Aud-Only':>10s}")
    print("-" * 76)
    for name, r in all_results.items():
        if "robustness" in r:
            rb = r["robustness"]
            print(f"{name:<42s} "
                  f"{rb.get('both_modalities', {}).get('macro_f1', 0):>10.4f} "
                  f"{rb.get('visual_only', {}).get('macro_f1', 0):>10.4f} "
                  f"{rb.get('audio_only', {}).get('macro_f1', 0):>10.4f}")

    # Hard-coded v1 numbers, printed for side-by-side comparison.
    print(f"\n{'='*80}")
    print("v1 BASELINE (for comparison):")
    print("  Both modalities: Acc=0.9423  F1=0.9481")
    print("  Audio only:      Acc=0.9231  F1=0.9316")
    print("  Visual only:     Acc=0.2308  F1=0.0750  ← MODALITY COLLAPSE")
    print(f"{'='*80}")


def run_ablation_suite(quick_test=False):
    """Run every ablation experiment sequentially, then summarize and save.

    Each experiment runs train_v2.py in a subprocess using the current Python
    interpreter. Afterwards the runner prints comparison tables and writes all
    results to ``./results/ablation_summary_v2.json``.

    Args:
        quick_test: If True, append ``--quick_test`` to every experiment's
            arguments.

    Returns:
        Dict mapping each experiment's display name to its parsed results dict,
        or to ``{"error": ...}`` when the run produced no usable new results.
    """
    all_results = {}
    for exp in _build_experiments(quick_test):
        print(f"\n{'='*60}")
        print(f"Running: {exp['name']}")
        print(f"{'='*60}")
        all_results[exp["name"]] = _run_experiment(exp)

    _print_summary(all_results)

    os.makedirs("./results", exist_ok=True)
    with open("./results/ablation_summary_v2.json", "w", encoding="utf-8") as f:
        json.dump(all_results, f, indent=2)
    print("\nFull results saved to ./results/ablation_summary_v2.json")
    return all_results


if __name__ == "__main__":
    import argparse
    parser = argparse.ArgumentParser()
    parser.add_argument("--quick_test", action="store_true")
    run_ablation_suite(quick_test=parser.parse_args().quick_test)