# Commit d077f18 (verified) by Ellaft:
#   "Add config_v2.py: OGM-GE hyperparameters + asymmetric LR config"
"""
Multimodal PC Fault Detection - Configuration v2
==================================================
All hyperparameters, dataset configs, and experiment settings in one place.
v2 additions:
- OGM-GE hyperparameters (ogm_alpha, ogm_noise_sigma)
- Auxiliary loss weights (lambda_visual, lambda_audio)
- Asymmetric LR multipliers (visual_lr_multiplier, audio_lr_multiplier)
"""
from dataclasses import dataclass, field
from typing import List, Optional, Literal
# ============================================================================
# Fault Taxonomy
# ============================================================================
# Integer label -> fault name. A class's position in this list *is* its label,
# so new classes should only ever be appended to keep existing label ids stable.
FAULT_CLASSES = list(
    (
        "normal_operation",  # label 0: normal PC running
        "boot_failure",      # label 1: BIOS beep codes, POST failures
        "overheating_fan",   # label 2: fan noise anomalies, thermal warnings
        "storage_failure",   # label 3: HDD clicking, drive errors
        "system_crash",      # label 4: BSOD, system halts
    )
)
# Number of output classes, derived from the list so the two cannot drift apart.
NUM_CLASSES = len(FAULT_CLASSES)
# ============================================================================
# ESC-50 → PC Fault Class Mapping (Audio) — kept for backward compat
# ============================================================================
# ESC-50 category name -> fault label (index into FAULT_CLASSES). Categories are
# listed grouped by fault label; the comprehension flattens the groups in order.
ESC50_TO_FAULT = {
    category: fault_label
    for fault_label, categories in enumerate(
        (
            ("keyboard_typing", "mouse_click"),               # 0: normal_operation
            ("clock_alarm", "siren"),                         # 1: boot_failure
            ("vacuum_cleaner", "engine", "washing_machine"),  # 2: overheating_fan
            ("clock_tick", "door_wood_knock", "hand_saw"),    # 3: storage_failure
            ("glass_breaking", "fireworks", "chainsaw"),      # 4: system_crash
        )
    )
    for category in categories
}
# ESC-50 category name -> that category's integer ``target`` id in the official
# ESC-50 metadata (meta/esc50.csv). Keys mirror ESC50_TO_FAULT.
# Fix: "mouse_click" is target 31. The previous value, 33, is "door_wood_creaks",
# so any lookup by target id would have returned door-creak clips instead of
# mouse clicks.
ESC50_CATEGORY_TO_TARGET = {
    "keyboard_typing": 32, "mouse_click": 31, "clock_alarm": 37,
    "siren": 42, "vacuum_cleaner": 36, "engine": 44,
    "washing_machine": 35, "clock_tick": 38, "door_wood_knock": 30,
    "hand_saw": 49, "glass_breaking": 39, "fireworks": 48, "chainsaw": 41,
}
# ============================================================================
# Visual Data Synthesis Config — kept for backward compat
# ============================================================================
# Per-fault settings for synthesizing visual samples: a free-text description,
# a dominant colour (presumably an RGB tuple), and the text lines to overlay.
# Keys follow the FAULT_CLASSES order.
VISUAL_SYNTHESIS = dict(
    normal_operation=dict(
        description="Clean desktop, green status indicators, normal task manager",
        color_dominant=(0, 128, 0),
        text_overlay=["System OK", "All services running", "Temperature: Normal"],
    ),
    boot_failure=dict(
        description="BIOS POST screen with error codes, black/blue background",
        color_dominant=(0, 0, 0),
        text_overlay=["BIOS ERROR", "POST Code: 3-3-1", "Memory Test Failed"],
    ),
    overheating_fan=dict(
        description="Temperature warning, red thermal display, CPU throttling",
        color_dominant=(255, 0, 0),
        text_overlay=["CRITICAL TEMP", "CPU: 98°C", "Thermal Throttling Active"],
    ),
    storage_failure=dict(
        description="Disk error screen, SMART warning, data recovery prompt",
        color_dominant=(255, 165, 0),
        text_overlay=["DISK ERROR", "S.M.A.R.T. WARNING", "Sector Read Failure"],
    ),
    system_crash=dict(
        description="Blue screen of death, kernel panic, stop code",
        color_dominant=(0, 120, 215),
        text_overlay=[
            "STOP: 0x0000007E",
            "SYSTEM_THREAD_EXCEPTION_NOT_HANDLED",
            "Your PC ran into a problem",
        ],
    ),
)
@dataclass
class DataConfig:
    """Dataset sources, audio front-end, image size and augmentation settings.

    Validated on construction: the STFT/mel sizes must be positive, the clip
    duration must be positive, and the mel band ``[fmin, fmax]`` must fit
    below the Nyquist frequency (``sample_rate / 2``).

    Raises:
        ValueError: if any of the above constraints is violated.
    """

    # Hugging Face dataset ids (AudioSet additionally needs a subset name).
    esc50_dataset: str = "ashraq/esc50"
    audioset_dataset: str = "agkphysics/AudioSet"
    audioset_config: str = "balanced"
    # Audio front-end. NOTE(review): the AST checkpoint used by the model
    # config ships a feature extractor that defaults to 16 kHz / 128 mel bins;
    # confirm these values are not fed to AST directly.
    sample_rate: int = 32000  # Hz
    audio_duration: float = 5.0  # presumably seconds (ESC-50 clips are 5 s)
    n_fft: int = 1024
    hop_length: int = 320
    n_mels: int = 64
    fmin: int = 50  # Hz, lower edge of the mel filterbank
    fmax: int = 14000  # Hz, upper edge; must not exceed sample_rate / 2
    # Visual input resolution in pixels.
    image_size: int = 224
    # Presumably the ESC-50 fold held out for validation — confirm in the loader.
    val_fold: int = 5
    num_synthetic_per_class: int = 200
    # Augmentation settings.
    audio_noise_snr_db: float = 10.0
    time_shift_max: float = 0.2
    freq_mask_max: int = 10
    time_mask_max: int = 20

    def __post_init__(self):
        # Fail fast on spectrogram settings that would otherwise surface much
        # later as empty mel filters or shape errors deep in preprocessing.
        for name in ("sample_rate", "n_fft", "hop_length", "n_mels"):
            value = getattr(self, name)
            if value <= 0:
                raise ValueError(f"DataConfig.{name} must be positive, got {value!r}")
        if self.audio_duration <= 0:
            raise ValueError(
                f"DataConfig.audio_duration must be positive, got {self.audio_duration!r}"
            )
        nyquist = self.sample_rate / 2
        if not 0 <= self.fmin < self.fmax <= nyquist:
            raise ValueError(
                "DataConfig requires 0 <= fmin < fmax <= sample_rate / 2 "
                f"({nyquist:g} Hz); got fmin={self.fmin!r}, fmax={self.fmax!r}"
            )
@dataclass
class ModelConfig:
    """Backbone checkpoints, fusion head and regularisation settings.

    ``fusion_type`` is checked against its ``Literal`` annotation on
    construction, because ``Literal`` is not enforced at runtime. Both dropout
    probabilities must lie in ``[0, 1]``, and the sizes must be positive.

    Raises:
        ValueError: if any of the above constraints is violated.
    """

    # Visual branch: ViT checkpoint id and its embedding width.
    vit_model_name: str = "google/vit-base-patch16-224-in21k"
    vit_embed_dim: int = 768
    # Audio branch: AST checkpoint id and its embedding width.
    ast_model_name: str = "MIT/ast-finetuned-audioset-10-10-0.4593"
    ast_embed_dim: int = 768
    # Fusion head.
    fusion_type: Literal["concat", "weighted_sum", "attention"] = "concat"
    fusion_dim: int = 512
    fusion_dropout: float = 0.3
    num_classes: int = NUM_CLASSES
    # Probability of dropping a modality (presumably during training only —
    # confirm in the model code).
    modality_dropout_p: float = 0.3

    def __post_init__(self):
        from typing import get_args, get_type_hints

        # Read the allowed values from the annotation so there is one source of truth.
        allowed = get_args(get_type_hints(ModelConfig)["fusion_type"])
        if self.fusion_type not in allowed:
            raise ValueError(
                f"ModelConfig.fusion_type must be one of {allowed}, got {self.fusion_type!r}"
            )
        for name in ("fusion_dropout", "modality_dropout_p"):
            p = getattr(self, name)
            if not 0.0 <= p <= 1.0:
                raise ValueError(f"ModelConfig.{name} must be in [0, 1], got {p!r}")
        for name in ("fusion_dim", "num_classes"):
            value = getattr(self, name)
            if value <= 0:
                raise ValueError(f"ModelConfig.{name} must be positive, got {value!r}")
@dataclass
class LoRAConfig:
    """LoRA adapter settings, with separate target/saved modules per backbone.

    Field names mirror PEFT's ``LoraConfig`` (``r``, ``lora_alpha``,
    ``lora_dropout``, ``bias``, ``target_modules``, ``modules_to_save``).
    When ``enabled`` is true, the values are validated on construction so that
    mistakes surface before any model is built. When it is false, the values
    are left unchecked.

    Raises:
        ValueError: if LoRA is enabled and ``r`` is not positive,
            ``lora_dropout`` is outside ``[0, 1]``, or ``bias`` is not an
            allowed mode.
    """

    enabled: bool = True
    r: int = 8  # adapter rank
    lora_alpha: int = 16  # LoRA scaling is lora_alpha / r in PEFT
    lora_dropout: float = 0.1
    # PEFT accepts only these bias modes.
    bias: Literal["none", "all", "lora_only"] = "none"
    # default_factory gives each instance its own list (no shared mutable default).
    vit_target_modules: List[str] = field(default_factory=lambda: ["query", "value"])
    vit_modules_to_save: List[str] = field(default_factory=lambda: ["classifier"])
    ast_target_modules: List[str] = field(default_factory=lambda: ["query", "value"])
    ast_modules_to_save: List[str] = field(default_factory=lambda: ["classifier"])

    def __post_init__(self):
        if not self.enabled:
            # Disabled adapters are never built, so their values are irrelevant.
            return
        from typing import get_args, get_type_hints

        if self.r <= 0:
            raise ValueError(f"LoRAConfig.r must be positive, got {self.r!r}")
        if not 0.0 <= self.lora_dropout <= 1.0:
            raise ValueError(
                f"LoRAConfig.lora_dropout must be in [0, 1], got {self.lora_dropout!r}"
            )
        allowed = get_args(get_type_hints(LoRAConfig)["bias"])
        if self.bias not in allowed:
            raise ValueError(f"LoRAConfig.bias must be one of {allowed}, got {self.bias!r}")
@dataclass
class TrainConfig:
    """Training-loop settings.

    Most field names match Hugging Face ``TrainingArguments`` (presumably
    forwarded to it — confirm in the training script). Validated on
    construction:

    * ``mode`` and ``finetune_method`` must be among their ``Literal`` values,
      which are not enforced at runtime otherwise.
    * Learning rates must be positive.
    * Epoch, batch-size and accumulation counts must be at least 1.
    * ``warmup_ratio`` must lie in ``[0, 1]``.
    * With ``load_best_model_at_end``, ``save_strategy`` must equal
      ``eval_strategy`` (or be ``"best"``), which is the same rule
      ``TrainingArguments`` enforces.

    Raises:
        ValueError: if any of the above constraints is violated.
    """

    mode: Literal["multimodal", "visual_only", "audio_only"] = "multimodal"
    finetune_method: Literal["lora", "full", "linear_probe"] = "lora"
    # Presumably the LR for full / linear-probe runs; the ablations override it
    # for those methods.
    learning_rate: float = 5e-4
    # LR for LoRA runs (the per-modality multipliers in ExperimentConfig scale this).
    lora_learning_rate: float = 5e-3
    weight_decay: float = 0.01
    warmup_ratio: float = 0.1
    max_grad_norm: float = 1.0
    num_epochs: int = 15
    per_device_train_batch_size: int = 16
    per_device_eval_batch_size: int = 32
    gradient_accumulation_steps: int = 2
    fp16: bool = True
    eval_strategy: str = "epoch"
    metric_for_best_model: str = "macro_f1"
    output_dir: str = "./results"
    push_to_hub: bool = True
    hub_model_id: str = "Ellaft/multimodal-pc-fault-detector"
    save_strategy: str = "epoch"
    save_total_limit: int = 3
    load_best_model_at_end: bool = True
    logging_steps: int = 10
    logging_strategy: str = "steps"
    logging_first_step: bool = True
    disable_tqdm: bool = True
    seed: int = 42

    def __post_init__(self):
        from typing import get_args, get_type_hints

        hints = get_type_hints(TrainConfig)
        for name in ("mode", "finetune_method"):
            allowed = get_args(hints[name])
            value = getattr(self, name)
            if value not in allowed:
                raise ValueError(f"TrainConfig.{name} must be one of {allowed}, got {value!r}")
        for name in ("learning_rate", "lora_learning_rate"):
            value = getattr(self, name)
            if value <= 0:
                raise ValueError(f"TrainConfig.{name} must be positive, got {value!r}")
        for name in (
            "num_epochs",
            "per_device_train_batch_size",
            "per_device_eval_batch_size",
            "gradient_accumulation_steps",
        ):
            value = getattr(self, name)
            if value < 1:
                raise ValueError(f"TrainConfig.{name} must be >= 1, got {value!r}")
        if not 0.0 <= self.warmup_ratio <= 1.0:
            raise ValueError(
                f"TrainConfig.warmup_ratio must be in [0, 1], got {self.warmup_ratio!r}"
            )
        # The best checkpoint can only be restored if checkpoints are written on
        # the same schedule as evaluations happen.
        if (
            self.load_best_model_at_end
            and self.save_strategy != "best"
            and self.save_strategy != self.eval_strategy
        ):
            raise ValueError(
                "TrainConfig.load_best_model_at_end requires save_strategy == eval_strategy; "
                f"got save_strategy={self.save_strategy!r}, eval_strategy={self.eval_strategy!r}"
            )
@dataclass
class ExperimentConfig:
    """Top-level experiment bundle: sub-configs plus the v2 anti-collapse knobs.

    The OGM-GE, auxiliary-loss and learning-rate-multiplier fields are checked
    on construction. Strengths and loss weights must be non-negative, and LR
    multipliers must be strictly positive, because a negative multiplier would
    flip the update direction.

    Raises:
        ValueError: if any anti-collapse hyperparameter is out of range.
    """

    data: DataConfig = field(default_factory=DataConfig)
    model: ModelConfig = field(default_factory=ModelConfig)
    lora: LoRAConfig = field(default_factory=LoRAConfig)
    train: TrainConfig = field(default_factory=TrainConfig)
    experiment_name: str = "multimodal_pc_fault_v2"
    description: str = "Two-branch audio-visual fusion for PC fault detection with OGM-GE anti-collapse"

    # ====================================================================
    # v2: Anti-modality-collapse hyperparameters
    # ====================================================================
    # OGM-GE: On-the-fly Gradient Modulation + Generalization Enhancement
    # From Peng et al., CVPR 2022 (arXiv: 2203.15332)
    ogm_alpha: float = 0.3  # Modulation strength (paper: 0.3-0.5)
    ogm_noise_sigma: float = 0.1  # GE noise std (paper: 0.1)

    # Auxiliary unimodal loss weights.
    # Asymmetric: boost visual (weak modality), dampen audio (dominant).
    lambda_visual: float = 1.5  # Weight for visual auxiliary loss
    lambda_audio: float = 0.5  # Weight for audio auxiliary loss

    # Asymmetric learning-rate multipliers (applied to train.lora_learning_rate).
    visual_lr_multiplier: float = 3.0  # Visual gets 3x base LR
    audio_lr_multiplier: float = 0.5  # Audio gets 0.5x base LR

    def __post_init__(self):
        for name in ("ogm_alpha", "ogm_noise_sigma", "lambda_visual", "lambda_audio"):
            value = getattr(self, name)
            if value < 0:
                raise ValueError(f"ExperimentConfig.{name} must be non-negative, got {value!r}")
        for name in ("visual_lr_multiplier", "audio_lr_multiplier"):
            value = getattr(self, name)
            if value <= 0:
                raise ValueError(f"ExperimentConfig.{name} must be positive, got {value!r}")
def get_ablation_configs():
    """Build the v2 ablation suite, keyed by experiment name.

    Experiments, in the insertion order of the returned dict:
      1. multimodal_lora_ogmge         - Multimodal + LoRA + OGM-GE (primary, full v2 pipeline)
      2. visual_only_lora              - Visual only + LoRA (unimodal baseline)
      3. audio_only_lora               - Audio only + LoRA (unimodal baseline)
      4. multimodal_full_ft_ogmge      - Multimodal + full fine-tuning + OGM-GE
      5. multimodal_linear_probe_ogmge - Multimodal + linear probe + OGM-GE
      6. multimodal_robust_ogmge       - Multimodal + LoRA + high modality dropout + OGM-GE

    Each value is a fresh ``ExperimentConfig`` whose ``experiment_name`` equals
    its key. Every field not overridden here keeps its dataclass default.

    Returns:
        dict mapping experiment name -> ExperimentConfig.
    """

    def make(name, mode, hub_model_id=None):
        # Fresh defaults tagged with a name and modality mode. The hub repo id
        # is overridden only when given; the primary run keeps the default.
        cfg = ExperimentConfig()
        cfg.experiment_name = name
        cfg.train.mode = mode
        if hub_model_id is not None:
            cfg.train.hub_model_id = hub_model_id
        return cfg

    primary = make("multimodal_lora_ogmge", "multimodal")
    visual_only = make("visual_only_lora", "visual_only", "Ellaft/pc-fault-visual-only")
    audio_only = make("audio_only_lora", "audio_only", "Ellaft/pc-fault-audio-only")

    full_ft = make(
        "multimodal_full_ft_ogmge", "multimodal", "Ellaft/pc-fault-multimodal-full-ft"
    )
    full_ft.train.finetune_method = "full"
    full_ft.lora.enabled = False
    full_ft.train.learning_rate = 2e-5  # all backbone weights train, so a much smaller LR

    linear_probe = make(
        "multimodal_linear_probe_ogmge", "multimodal", "Ellaft/pc-fault-multimodal-linear-probe"
    )
    linear_probe.train.finetune_method = "linear_probe"
    linear_probe.lora.enabled = False
    linear_probe.train.learning_rate = 1e-3

    robust = make(
        "multimodal_robust_ogmge", "multimodal", "Ellaft/pc-fault-multimodal-robust"
    )
    robust.model.modality_dropout_p = 0.5  # heavier modality dropout for robustness

    ordered = (primary, visual_only, audio_only, full_ft, linear_probe, robust)
    return {cfg.experiment_name: cfg for cfg in ordered}