"""
Multimodal PC Fault Detection - Configuration v2
==================================================

All hyperparameters, dataset configs, and experiment settings in one place.

v2 additions:
  - OGM-GE hyperparameters (ogm_alpha, ogm_noise_sigma)
  - Auxiliary loss weights (lambda_visual, lambda_audio)
  - Asymmetric LR multipliers (visual_lr_multiplier, audio_lr_multiplier)
"""

from dataclasses import dataclass, field
from typing import Dict, List, Optional, Literal


# ============================================================================
# Fault Taxonomy
# ============================================================================

# Index in this list == integer class label used throughout the project.
FAULT_CLASSES = [
    "normal_operation",  # 0: Normal PC running
    "boot_failure",      # 1: BIOS beep codes, POST failures
    "overheating_fan",   # 2: Fan noise anomalies, thermal warnings
    "storage_failure",   # 3: HDD clicking, drive errors
    "system_crash",      # 4: BSOD, system halts
]

NUM_CLASSES = len(FAULT_CLASSES)


# ============================================================================
# ESC-50 → PC Fault Class Mapping (Audio) — kept for backward compat
# ============================================================================

# ESC-50 category name -> index into FAULT_CLASSES.
ESC50_TO_FAULT = {
    "keyboard_typing": 0,
    "mouse_click": 0,
    "clock_alarm": 1,
    "siren": 1,
    "vacuum_cleaner": 2,
    "engine": 2,
    "washing_machine": 2,
    "clock_tick": 3,
    "door_wood_knock": 3,
    "hand_saw": 3,
    "glass_breaking": 4,
    "fireworks": 4,
    "chainsaw": 4,
}

# ESC-50 category name -> that category's numeric `target` label in ESC-50
# (0-49). Keys must match ESC50_TO_FAULT. Note: ESC-50 assigns 31 to
# "mouse_click"; 33 is "door_wood_creaks", so using 33 here would select
# the wrong clips.
ESC50_CATEGORY_TO_TARGET = {
    "keyboard_typing": 32,
    "mouse_click": 31,
    "clock_alarm": 37,
    "siren": 42,
    "vacuum_cleaner": 36,
    "engine": 44,
    "washing_machine": 35,
    "clock_tick": 38,
    "door_wood_knock": 30,
    "hand_saw": 49,
    "glass_breaking": 39,
    "fireworks": 48,
    "chainsaw": 41,
}


# ============================================================================
# Visual Data Synthesis Config — kept for backward compat
# ============================================================================

# Per-class recipe for synthetic images, keyed by FAULT_CLASSES names.
# "color_dominant" is an (R, G, B) tuple; "text_overlay" lists strings.
VISUAL_SYNTHESIS = {
    "normal_operation": {
        "description": "Clean desktop, green status indicators, normal task manager",
        "color_dominant": (0, 128, 0),
        "text_overlay": ["System OK", "All services running", "Temperature: Normal"],
    },
    "boot_failure": {
        "description": "BIOS POST screen with error codes, black/blue background",
        "color_dominant": (0, 0, 0),
        "text_overlay": ["BIOS ERROR", "POST Code: 3-3-1", "Memory Test Failed"],
    },
    "overheating_fan": {
        "description": "Temperature warning, red thermal display, CPU throttling",
        "color_dominant": (255, 0, 0),
        "text_overlay": ["CRITICAL TEMP", "CPU: 98°C", "Thermal Throttling Active"],
    },
    "storage_failure": {
        "description": "Disk error screen, SMART warning, data recovery prompt",
        "color_dominant": (255, 165, 0),
        "text_overlay": ["DISK ERROR", "S.M.A.R.T. WARNING", "Sector Read Failure"],
    },
    "system_crash": {
        "description": "Blue screen of death, kernel panic, stop code",
        "color_dominant": (0, 120, 215),
        "text_overlay": [
            "STOP: 0x0000007E",
            "SYSTEM_THREAD_EXCEPTION_NOT_HANDLED",
            "Your PC ran into a problem",
        ],
    },
}


@dataclass
class DataConfig:
    """Dataset identifiers plus audio/image preprocessing and augmentation settings."""

    esc50_dataset: str = "ashraq/esc50"
    audioset_dataset: str = "agkphysics/AudioSet"
    audioset_config: str = "balanced"

    # Audio front-end
    sample_rate: int = 32000
    audio_duration: float = 5.0
    n_fft: int = 1024
    hop_length: int = 320
    n_mels: int = 64
    fmin: int = 50
    fmax: int = 14000

    # Visual input
    image_size: int = 224

    val_fold: int = 5
    num_synthetic_per_class: int = 200

    # Augmentation
    audio_noise_snr_db: float = 10.0
    time_shift_max: float = 0.2
    freq_mask_max: int = 10
    time_mask_max: int = 20


@dataclass
class ModelConfig:
    """Backbone checkpoints and fusion-head settings."""

    vit_model_name: str = "google/vit-base-patch16-224-in21k"
    vit_embed_dim: int = 768
    ast_model_name: str = "MIT/ast-finetuned-audioset-10-10-0.4593"
    ast_embed_dim: int = 768
    fusion_type: Literal["concat", "weighted_sum", "attention"] = "concat"
    fusion_dim: int = 512
    fusion_dropout: float = 0.3
    num_classes: int = NUM_CLASSES
    modality_dropout_p: float = 0.3


@dataclass
class LoRAConfig:
    """LoRA adapter settings for the visual (ViT) and audio (AST) branches."""

    enabled: bool = True
    r: int = 8
    lora_alpha: int = 16
    lora_dropout: float = 0.1
    bias: str = "none"
    # default_factory so each instance gets its own list (no shared mutable default).
    vit_target_modules: List[str] = field(default_factory=lambda: ["query", "value"])
    vit_modules_to_save: List[str] = field(default_factory=lambda: ["classifier"])
    ast_target_modules: List[str] = field(default_factory=lambda: ["query", "value"])
    ast_modules_to_save: List[str] = field(default_factory=lambda: ["classifier"])


@dataclass
class TrainConfig:
    """Training-loop, checkpointing, logging, and Hub-publishing settings."""

    mode: Literal["multimodal", "visual_only", "audio_only"] = "multimodal"
    finetune_method: Literal["lora", "full", "linear_probe"] = "lora"
    learning_rate: float = 5e-4
    lora_learning_rate: float = 5e-3
    weight_decay: float = 0.01
    warmup_ratio: float = 0.1
    max_grad_norm: float = 1.0
    num_epochs: int = 15
    per_device_train_batch_size: int = 16
    per_device_eval_batch_size: int = 32
    gradient_accumulation_steps: int = 2
    fp16: bool = True
    eval_strategy: str = "epoch"
    metric_for_best_model: str = "macro_f1"
    output_dir: str = "./results"
    push_to_hub: bool = True
    hub_model_id: str = "Ellaft/multimodal-pc-fault-detector"
    save_strategy: str = "epoch"
    save_total_limit: int = 3
    load_best_model_at_end: bool = True
    logging_steps: int = 10
    logging_strategy: str = "steps"
    logging_first_step: bool = True
    disable_tqdm: bool = True
    seed: int = 42


@dataclass
class ExperimentConfig:
    """Top-level experiment config bundling data/model/LoRA/train sub-configs.

    Each sub-config is built via ``default_factory``, so every
    ``ExperimentConfig()`` owns independent sub-config instances and can be
    mutated without affecting others.
    """

    data: DataConfig = field(default_factory=DataConfig)
    model: ModelConfig = field(default_factory=ModelConfig)
    lora: LoRAConfig = field(default_factory=LoRAConfig)
    train: TrainConfig = field(default_factory=TrainConfig)

    experiment_name: str = "multimodal_pc_fault_v2"
    description: str = "Two-branch audio-visual fusion for PC fault detection with OGM-GE anti-collapse"

    # ====================================================================
    # v2: Anti-modality-collapse hyperparameters
    # ====================================================================

    # OGM-GE: On-the-fly Gradient Modulation + Generalization Enhancement
    # From Peng et al., CVPR 2022 (arXiv: 2203.15332)
    ogm_alpha: float = 0.3        # Modulation strength (paper: 0.3-0.5)
    ogm_noise_sigma: float = 0.1  # GE noise std (paper: 0.1)

    # Auxiliary unimodal loss weights
    # Asymmetric: boost visual (weak modality), dampen audio (dominant)
    lambda_visual: float = 1.5  # Weight for visual auxiliary loss
    lambda_audio: float = 0.5   # Weight for audio auxiliary loss

    # Asymmetric learning rate multipliers (applied to lora_learning_rate)
    visual_lr_multiplier: float = 3.0  # Visual gets 3x base LR
    audio_lr_multiplier: float = 0.5   # Audio gets 0.5x base LR


def get_ablation_configs() -> Dict[str, ExperimentConfig]:
    """
    Generate ablation experiment configurations.

    v2 experiments:
      1. Multimodal + LoRA + OGM-GE (the full v2 pipeline)
      2. Visual Only + LoRA (unimodal baseline)
      3. Audio Only + LoRA (unimodal baseline)
      4. Multimodal + Full FT + OGM-GE
      5. Multimodal + Linear Probe + OGM-GE
      6. Multimodal + LoRA + High Dropout + OGM-GE (robustness)

    Returns:
        Insertion-ordered dict mapping each experiment name to a fresh
        ``ExperimentConfig``; each key equals that config's
        ``experiment_name``. Configs share no mutable state.
    """
    configs: Dict[str, ExperimentConfig] = {}

    # 1. Multimodal + LoRA + OGM-GE (PRIMARY) — keeps the default hub_model_id
    cfg = ExperimentConfig()
    cfg.train.mode = "multimodal"
    cfg.experiment_name = "multimodal_lora_ogmge"
    configs["multimodal_lora_ogmge"] = cfg

    # 2. Visual Only + LoRA
    cfg = ExperimentConfig()
    cfg.train.mode = "visual_only"
    cfg.experiment_name = "visual_only_lora"
    cfg.train.hub_model_id = "Ellaft/pc-fault-visual-only"
    configs["visual_only_lora"] = cfg

    # 3. Audio Only + LoRA
    cfg = ExperimentConfig()
    cfg.train.mode = "audio_only"
    cfg.experiment_name = "audio_only_lora"
    cfg.train.hub_model_id = "Ellaft/pc-fault-audio-only"
    configs["audio_only_lora"] = cfg

    # 4. Multimodal + Full FT + OGM-GE (LoRA off; much smaller LR for full FT)
    cfg = ExperimentConfig()
    cfg.train.mode = "multimodal"
    cfg.train.finetune_method = "full"
    cfg.lora.enabled = False
    cfg.train.learning_rate = 2e-5
    cfg.experiment_name = "multimodal_full_ft_ogmge"
    cfg.train.hub_model_id = "Ellaft/pc-fault-multimodal-full-ft"
    configs["multimodal_full_ft_ogmge"] = cfg

    # 5. Multimodal + Linear Probe + OGM-GE (LoRA off; larger LR for head only)
    cfg = ExperimentConfig()
    cfg.train.mode = "multimodal"
    cfg.train.finetune_method = "linear_probe"
    cfg.lora.enabled = False
    cfg.train.learning_rate = 1e-3
    cfg.experiment_name = "multimodal_linear_probe_ogmge"
    cfg.train.hub_model_id = "Ellaft/pc-fault-multimodal-linear-probe"
    configs["multimodal_linear_probe_ogmge"] = cfg

    # 6. Multimodal + LoRA + High Dropout + OGM-GE
    cfg = ExperimentConfig()
    cfg.train.mode = "multimodal"
    cfg.model.modality_dropout_p = 0.5
    cfg.experiment_name = "multimodal_robust_ogmge"
    cfg.train.hub_model_id = "Ellaft/pc-fault-multimodal-robust"
    configs["multimodal_robust_ogmge"] = cfg

    return configs