File size: 5,111 Bytes
34b5fb4 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 | """
Multimodal PC Fault Detection - Ablation Runner v2
====================================================
Runs all ablation experiments using train_v2.py with OGM-GE.
Usage: python run_ablations_v2.py [--quick_test]
"""
import json
import os
import subprocess
import sys
import time

from config import FAULT_CLASSES
# Slack (seconds) when deciding whether a results file was written by the
# current run: filesystems with coarse (1-2 s) timestamps could otherwise make
# a freshly written file look older than the recorded start time.
_MTIME_SLACK_S = 2.0


def _build_experiments(quick_test):
    """Return a fresh list of experiment specs (display name + train_v2.py CLI args)."""
    # NOTE(review): several --hub_model_id values carry no v2-specific suffix;
    # confirm pushing to them cannot overwrite earlier (v1) models.
    experiments = [
        {
            "name": "Multimodal + LoRA + OGM-GE",
            "args": [
                "--mode", "multimodal", "--finetune", "lora",
                "--fusion", "concat", "--eval_robustness",
                "--output_dir", "./results/multimodal_lora_ogmge",
                "--hub_model_id", "Ellaft/pc-fault-multimodal-lora-ogmge",
            ],
        },
        {
            "name": "Visual Only + LoRA",
            "args": [
                "--mode", "visual_only", "--finetune", "lora",
                "--output_dir", "./results/visual_only_lora",
                "--hub_model_id", "Ellaft/pc-fault-visual-only",
            ],
        },
        {
            "name": "Audio Only + LoRA",
            "args": [
                "--mode", "audio_only", "--finetune", "lora",
                "--output_dir", "./results/audio_only_lora",
                "--hub_model_id", "Ellaft/pc-fault-audio-only",
            ],
        },
        {
            "name": "Multimodal + Full FT + OGM-GE",
            "args": [
                "--mode", "multimodal", "--finetune", "full",
                "--lr", "2e-5", "--fusion", "concat",
                "--output_dir", "./results/multimodal_full_ft_ogmge",
                "--hub_model_id", "Ellaft/pc-fault-multimodal-full-ft",
            ],
        },
        {
            "name": "Multimodal + Linear Probe + OGM-GE",
            "args": [
                "--mode", "multimodal", "--finetune", "linear_probe",
                "--lr", "1e-3", "--fusion", "concat",
                "--output_dir", "./results/multimodal_linear_probe_ogmge",
                "--hub_model_id", "Ellaft/pc-fault-multimodal-linear-probe",
            ],
        },
        {
            "name": "Multimodal + LoRA + High Dropout + OGM-GE",
            "args": [
                "--mode", "multimodal", "--finetune", "lora",
                "--modality_dropout", "0.5", "--fusion", "concat",
                "--eval_robustness",
                "--output_dir", "./results/multimodal_robust_ogmge",
                "--hub_model_id", "Ellaft/pc-fault-multimodal-robust",
            ],
        },
    ]
    if quick_test:
        for exp in experiments:
            exp["args"].append("--quick_test")
    return experiments


def _run_experiment(exp):
    """Run train_v2.py for one experiment and return its results dict.

    Returns the parsed ``results_v2.json`` from the experiment's
    ``--output_dir`` if this run wrote it, otherwise ``{"error": ...}``
    (missing, stale, unreadable, or non-object results file).
    """
    args = exp["args"]
    output_dir = args[args.index("--output_dir") + 1]
    rpath = os.path.join(output_dir, "results_v2.json")

    # Flush our own banner first: the child writes straight to the inherited
    # stdout, so unflushed parent output would appear out of order in logs.
    sys.stdout.flush()
    started = time.time()
    proc = subprocess.run([sys.executable, "train_v2.py", *args])

    try:
        mtime = os.path.getmtime(rpath)
    except OSError:
        return {"error": f"rc={proc.returncode}"}
    if mtime < started - _MTIME_SLACK_S:
        # Left over from an earlier invocation; it must not be reported as
        # this run's outcome. The file itself is left untouched.
        return {"error": f"rc={proc.returncode} (ignored stale {rpath})"}

    try:
        with open(rpath, encoding="utf-8") as f:
            data = json.load(f)
    except (OSError, ValueError) as err:  # ValueError covers JSONDecodeError
        return {"error": f"rc={proc.returncode}, unreadable results: {err}"}
    if not isinstance(data, dict):
        return {"error": f"rc={proc.returncode}, unexpected results format"}
    return data


def _fmt_metric(value, width):
    """Right-align *value* to *width* columns with 4 decimals, or 'n/a' if not numeric."""
    if isinstance(value, (int, float)) and not isinstance(value, bool):
        return f"{value:>{width}.4f}"
    return f"{'n/a':>{width}s}"


def _save_summary(all_results):
    """Write the accumulated results to ./results/ablation_summary_v2.json."""
    os.makedirs("./results", exist_ok=True)
    with open("./results/ablation_summary_v2.json", "w", encoding="utf-8") as f:
        json.dump(all_results, f, indent=2)


def _print_summary(all_results):
    """Print the metric table, the missing-modality table and the v1 reminder."""
    bar = "=" * 80

    # Headline metrics, read from each run's "final_metrics" entry.
    print(f"\n\n{bar}")
    print("ABLATION STUDY RESULTS (v2 — with OGM-GE anti-collapse)")
    print(bar)
    header = (f"{'Experiment':<42s} {'Accuracy':>10s} "
              f"{'Macro-F1':>10s} {'Weighted-F1':>12s}")
    print(header)
    print("-" * len(header))
    for name, r in all_results.items():
        if "error" in r:
            print(f"{name:<42s} {'ERROR':>10s}  ({r['error']})")
            continue
        m = r.get("final_metrics") or {}
        print(f"{name:<42s} {_fmt_metric(m.get('accuracy'), 10)} "
              f"{_fmt_metric(m.get('macro_f1'), 10)} "
              f"{_fmt_metric(m.get('weighted_f1'), 12)}")

    # Missing-modality macro-F1; only runs whose results contain a
    # "robustness" dict are listed (those launched with --eval_robustness).
    print(f"\n{bar}")
    print("MISSING MODALITY ROBUSTNESS")
    print(bar)
    header = (f"{'Experiment':<42s} {'Both':>10s} "
              f"{'Vis-Only':>10s} {'Aud-Only':>10s}")
    print(header)
    print("-" * len(header))
    for name, r in all_results.items():
        rb = r.get("robustness")
        if not isinstance(rb, dict):
            continue
        cells = " ".join(
            _fmt_metric((rb.get(key) or {}).get("macro_f1"), 10)
            for key in ("both_modalities", "visual_only", "audio_only"))
        print(f"{name:<42s} {cells}")

    # v1 reference numbers, hard-coded here for side-by-side comparison.
    print(f"\n{bar}")
    print("v1 BASELINE (for comparison):")
    print(" Both modalities: Acc=0.9423 F1=0.9481")
    print(" Audio only: Acc=0.9231 F1=0.9316")
    print(" Visual only: Acc=0.2308 F1=0.0750 ← MODALITY COLLAPSE")
    print(bar)


def run_ablation_suite(quick_test=False):
    """Run every ablation experiment via train_v2.py and summarize the results.

    Each experiment is launched as a separate ``train_v2.py`` subprocess with
    the current interpreter. Its ``results_v2.json`` (from ``--output_dir``)
    is collected, and a summary is printed and saved to
    ``./results/ablation_summary_v2.json``. The summary file is rewritten after
    every experiment, so finished runs survive a crash or interrupt.

    Args:
        quick_test: If True, ``--quick_test`` is appended to every
            experiment's command line.

    Returns:
        Dict mapping experiment name to its parsed results dict, or to
        ``{"error": "..."}`` when the run produced no fresh, readable results.
    """
    all_results = {}
    for exp in _build_experiments(quick_test):
        print(f"\n{'=' * 60}")
        print(f"Running: {exp['name']}")
        print("=" * 60)
        all_results[exp["name"]] = _run_experiment(exp)
        _save_summary(all_results)

    _print_summary(all_results)
    print("\nFull results saved to ./results/ablation_summary_v2.json")
    return all_results
if __name__ == "__main__":
    import argparse

    # Single CLI switch; its value is forwarded unchanged to the suite runner.
    arg_parser = argparse.ArgumentParser()
    arg_parser.add_argument("--quick_test", action="store_true")
    cli_opts = arg_parser.parse_args()
    run_ablation_suite(quick_test=cli_opts.quick_test)
|