# multimodal-pc-fault-detector/src/run_ablations_v2.py
# Author: Ellaft
# Commit 34b5fb4 (verified): Add run_ablations_v2.py: v2 ablation runner with OGM-GE experiments
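"""
Multimodal PC Fault Detection - Ablation Runner v2
====================================================
Runs all ablation experiments through train_v2.py (the OGM-GE training
script), prints a comparison table, and saves every experiment's results to
./results/ablation_summary_v2.json (relative to the current directory).
Usage: python run_ablations_v2.py [--quick_test]
  --quick_test  is forwarded to every train_v2.py run.
"""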
import os, sys, json, subprocess
from config import FAULT_CLASSES
def _build_experiments(quick_test):
    """Return the ablation experiment specs.

    Each spec is a dict with a human-readable ``"name"`` and an ``"args"`` list
    that is forwarded verbatim to ``train_v2.py``. A fresh list is built on
    every call, so appending ``--quick_test`` never leaks between calls.

    Args:
        quick_test: When true, append ``--quick_test`` to every experiment.
    """
    experiments = [
        {
            "name": "Multimodal + LoRA + OGM-GE",
            "args": [
                "--mode", "multimodal", "--finetune", "lora",
                "--fusion", "concat", "--eval_robustness",
                "--output_dir", "./results/multimodal_lora_ogmge",
                "--hub_model_id", "Ellaft/pc-fault-multimodal-lora-ogmge",
            ],
        },
        {
            "name": "Visual Only + LoRA",
            "args": [
                "--mode", "visual_only", "--finetune", "lora",
                "--output_dir", "./results/visual_only_lora",
                "--hub_model_id", "Ellaft/pc-fault-visual-only",
            ],
        },
        {
            "name": "Audio Only + LoRA",
            "args": [
                "--mode", "audio_only", "--finetune", "lora",
                "--output_dir", "./results/audio_only_lora",
                "--hub_model_id", "Ellaft/pc-fault-audio-only",
            ],
        },
        {
            "name": "Multimodal + Full FT + OGM-GE",
            "args": [
                "--mode", "multimodal", "--finetune", "full",
                "--lr", "2e-5", "--fusion", "concat",
                "--output_dir", "./results/multimodal_full_ft_ogmge",
                "--hub_model_id", "Ellaft/pc-fault-multimodal-full-ft",
            ],
        },
        {
            "name": "Multimodal + Linear Probe + OGM-GE",
            "args": [
                "--mode", "multimodal", "--finetune", "linear_probe",
                "--lr", "1e-3", "--fusion", "concat",
                "--output_dir", "./results/multimodal_linear_probe_ogmge",
                "--hub_model_id", "Ellaft/pc-fault-multimodal-linear-probe",
            ],
        },
        {
            "name": "Multimodal + LoRA + High Dropout + OGM-GE",
            "args": [
                "--mode", "multimodal", "--finetune", "lora",
                "--modality_dropout", "0.5", "--fusion", "concat",
                "--eval_robustness",
                "--output_dir", "./results/multimodal_robust_ogmge",
                "--hub_model_id", "Ellaft/pc-fault-multimodal-robust",
            ],
        },
    ]
    if quick_test:
        for exp in experiments:
            exp["args"].append("--quick_test")
    return experiments


def _train_script_path():
    """Return the path to ``train_v2.py``, preferring the copy next to this file.

    A bare relative name only works when the CWD is this file's directory
    (e.g. it fails for ``python src/run_ablations_v2.py``). If no sibling
    ``train_v2.py`` exists, fall back to the bare name, as before.
    """
    sibling = os.path.join(os.path.dirname(os.path.abspath(__file__)), "train_v2.py")
    return sibling if os.path.isfile(sibling) else "train_v2.py"


def _mtime_ns(path):
    """Return ``path``'s modification time in nanoseconds, or None if it can't be stat'ed."""
    try:
        return os.stat(path).st_mtime_ns
    except OSError:
        return None


def _run_experiment(exp, train_script):
    """Run one experiment in a child process and return its results dict.

    The child inherits this process's CWD, so ``--output_dir`` resolves the
    same way here as in ``train_v2.py``.

    Returns:
        The parsed ``<output_dir>/results_v2.json`` if this run (re)wrote it;
        otherwise ``{"error": "rc=<returncode>..."}``.
    """
    print(f"\n{'='*60}")
    print(f"Running: {exp['name']}")
    print(f"{'='*60}")
    args = exp["args"]
    results_dir = args[args.index("--output_dir") + 1]
    rpath = os.path.join(results_dir, "results_v2.json")

    # Snapshot the mtime first: a results file that this run did not touch is
    # left over from an earlier run and must not be reported as fresh metrics.
    before = _mtime_ns(rpath)
    result = subprocess.run(
        [sys.executable, train_script] + args,
        capture_output=False)
    after = _mtime_ns(rpath)

    if after is None:
        return {"error": f"rc={result.returncode}"}
    if after == before:
        return {"error": f"rc={result.returncode} (stale {rpath} ignored)"}
    try:
        with open(rpath) as f:
            return json.load(f)
    except (OSError, json.JSONDecodeError) as err:
        # Don't let one unreadable file abort the summary of every other run.
        return {"error": f"rc={result.returncode} (unreadable {rpath}: {err})"}


def _fmt_metric(value, width):
    """Right-align a metric to ``width`` columns; non-numeric/missing values show as N/A."""
    if isinstance(value, (int, float)):
        return f"{value:>{width}.4f}"
    return f"{'N/A':>{width}s}"


def _print_summary(all_results):
    """Print the accuracy / macro-F1 / weighted-F1 table for every experiment."""
    print(f"\n\n{'='*80}")
    print("ABLATION STUDY RESULTS (v2 — with OGM-GE anti-collapse)")
    print(f"{'='*80}")
    header = f"{'Experiment':<42s} {'Accuracy':>10s} {'Macro-F1':>10s} {'Weighted-F1':>12s}"
    print(header)
    print("-" * len(header))
    for name, r in all_results.items():
        if "error" in r:
            print(f"{name:<42s} {'ERROR':>10s}  {r['error']}")
            continue
        # Tolerate results files written without (some of) the expected metrics.
        m = r.get("final_metrics") or {}
        print(f"{name:<42s} "
              f"{_fmt_metric(m.get('accuracy'), 10)} "
              f"{_fmt_metric(m.get('macro_f1'), 10)} "
              f"{_fmt_metric(m.get('weighted_f1'), 12)}")


def _print_robustness(all_results):
    """Print macro-F1 per available modality for experiments that report robustness."""
    print(f"\n{'='*80}")
    print("MISSING MODALITY ROBUSTNESS")
    print(f"{'='*80}")
    header = f"{'Experiment':<42s} {'Both':>10s} {'Vis-Only':>10s} {'Aud-Only':>10s}"
    print(header)
    print("-" * len(header))
    for name, r in all_results.items():
        if "robustness" not in r:
            continue
        rb = r["robustness"]
        # Missing entries print as N/A rather than a fake 0.0000 score.
        print(f"{name:<42s} "
              f"{_fmt_metric(rb.get('both_modalities', {}).get('macro_f1'), 10)} "
              f"{_fmt_metric(rb.get('visual_only', {}).get('macro_f1'), 10)} "
              f"{_fmt_metric(rb.get('audio_only', {}).get('macro_f1'), 10)}")


def _print_v1_baseline():
    """Print the hard-coded v1 baseline numbers for side-by-side comparison."""
    print(f"\n{'='*80}")
    print("v1 BASELINE (for comparison):")
    print(" Both modalities: Acc=0.9423 F1=0.9481")
    print(" Audio only: Acc=0.9231 F1=0.9316")
    print(" Visual only: Acc=0.2308 F1=0.0750 ← MODALITY COLLAPSE")
    print(f"{'='*80}")


def run_ablation_suite(quick_test=False):
    """Run every v2 ablation experiment, then print and save a comparison.

    Each experiment launches ``train_v2.py`` in a child process that inherits
    the current working directory, so all ``./results/...`` paths are
    relative to it.

    Args:
        quick_test: Forward ``--quick_test`` to every training run.

    Returns:
        Dict mapping experiment name to the parsed ``results_v2.json``, or to
        ``{"error": ...}`` when the run produced no fresh, readable results.
        The same dict is written to ``./results/ablation_summary_v2.json``.
    """
    train_script = _train_script_path()
    all_results = {}
    for exp in _build_experiments(quick_test):
        all_results[exp["name"]] = _run_experiment(exp, train_script)

    _print_summary(all_results)
    _print_robustness(all_results)
    _print_v1_baseline()

    os.makedirs("./results", exist_ok=True)
    summary_path = "./results/ablation_summary_v2.json"
    with open(summary_path, "w") as f:
        json.dump(all_results, f, indent=2)
    print(f"\nFull results saved to {summary_path}")
    return all_results
if __name__ == "__main__":
    # Script entry point: parse the single CLI flag and run the whole suite.
    import argparse

    parser = argparse.ArgumentParser(
        description=__doc__,
        formatter_class=argparse.RawDescriptionHelpFormatter,
    )
    parser.add_argument(
        "--quick_test",
        action="store_true",
        help="forward --quick_test to every train_v2.py run",
    )
    run_ablation_suite(quick_test=parser.parse_args().quick_test)