| """ |
| Multimodal PC Fault Detection - Configuration |
| ============================================== |
| Hyperparameters, dataset configs, experiment settings, and anti-modality-collapse params. |
| |
| Includes OGM-GE (Peng et al., CVPR 2022) hyperparameters, auxiliary loss weights, |
| and asymmetric learning rate multipliers. |
| """ |
|
|
| from dataclasses import dataclass, field |
| from typing import List, Optional, Literal |
|
|
|
|
| |
| |
| |
# Labels the classifier predicts. Presumably a label's position in this list is
# its integer class id (TODO confirm against the dataset / label encoding code).
FAULT_CLASSES: List[str] = [
    "normal_operation",  # healthy baseline, no fault
    "boot_failure",
    "overheating_fan",
    "storage_failure",
    "system_crash",
]

# Number of output classes; derived from the label list (and used as the
# ModelConfig.num_classes default) so the two can never drift apart.
NUM_CLASSES: int = len(FAULT_CLASSES)
|
|
|
|
@dataclass
class DataConfig:
    """Audio / image preprocessing and augmentation settings.

    Construction-time checks (``__post_init__``):

    * size-like fields must be positive and ``0 <= fmin < fmax`` must hold,
      otherwise ``ValueError`` is raised;
    * ``fmax`` above the Nyquist frequency (``sample_rate / 2``) only emits a
      ``UserWarning``: mel bands above Nyquist have no spectral content, but
      the configuration is still usable.

    NOTE(review): the AST checkpoint named in ModelConfig ships a Hugging Face
    feature extractor configured for 16 kHz input and 128 mel bins, while the
    defaults here are 32 kHz / 64 mels. Confirm which front end actually
    consumes these values for the audio branch.
    """

    # --- Audio signal and log-mel front end ---
    sample_rate: int = 32000  # Hz
    audio_duration: float = 5.0  # presumably seconds per clip; TODO confirm in the loader
    n_fft: int = 1024  # FFT size in samples
    hop_length: int = 320  # hop in samples (10 ms at the default 32 kHz)
    n_mels: int = 64
    fmin: int = 50  # lowest mel band edge, Hz
    fmax: int = 14000  # highest mel band edge, Hz; should not exceed sample_rate / 2

    # --- Visual input ---
    image_size: int = 224  # presumably the square side length fed to the 224-px ViT

    # --- Augmentation ---
    audio_noise_snr_db: float = 10.0
    time_shift_max: float = 0.2  # NOTE(review): units (seconds vs. fraction of clip) not visible here
    freq_mask_max: int = 10  # presumably max masked mel bins (SpecAugment-style)
    time_mask_max: int = 20  # presumably max masked time frames

    def __post_init__(self) -> None:
        """Reject impossible spectral settings and warn about lossy ones.

        Raises:
            ValueError: if a size-like field is not positive, or if
                ``0 <= fmin < fmax`` does not hold.
        """
        import warnings

        for name in (
            "sample_rate",
            "audio_duration",
            "n_fft",
            "hop_length",
            "n_mels",
            "image_size",
        ):
            value = getattr(self, name)
            if value <= 0:
                raise ValueError(f"DataConfig.{name} must be positive, got {value!r}")

        if not 0 <= self.fmin < self.fmax:
            raise ValueError(
                f"DataConfig requires 0 <= fmin < fmax, got fmin={self.fmin}, fmax={self.fmax}"
            )

        nyquist = self.sample_rate / 2
        if self.fmax > nyquist:
            # Not fatal: mel filters above Nyquist are simply empty, so only warn
            # rather than break configs that previously constructed fine.
            warnings.warn(
                f"DataConfig.fmax={self.fmax} Hz exceeds the Nyquist frequency "
                f"({nyquist:g} Hz) for sample_rate={self.sample_rate}; "
                "upper mel bands will carry no signal.",
                UserWarning,
                stacklevel=3,
            )
|
|
|
|
@dataclass
class ModelConfig:
    """Backbone, fusion-head and regularisation settings for the two-branch model.

    ``Literal`` annotations are not enforced by Python at runtime, so
    ``__post_init__`` checks ``fusion_type`` explicitly. It also checks that
    the dropout probabilities lie in ``[0, 1]`` and that dimensions are positive.
    """

    # Visual branch: Hugging Face hub id and its expected embedding width.
    vit_model_name: str = "google/vit-base-patch16-224-in21k"
    vit_embed_dim: int = 768
    # Audio branch: Hugging Face hub id and its expected embedding width.
    ast_model_name: str = "MIT/ast-finetuned-audioset-10-10-0.4593"
    ast_embed_dim: int = 768
    # Strategy for combining the two branch embeddings.
    fusion_type: Literal["concat", "weighted_sum", "attention"] = "concat"
    fusion_dim: int = 512
    fusion_dropout: float = 0.3
    num_classes: int = NUM_CLASSES  # tied to the FAULT_CLASSES label list
    # Probability of dropping a modality during training (anti-modality-collapse);
    # presumably applied per sample — TODO confirm in the model code.
    modality_dropout_p: float = 0.3

    def __post_init__(self) -> None:
        """Validate enumerated and range-constrained fields.

        Raises:
            ValueError: if ``fusion_type`` is not an allowed value, a dropout
                probability is outside ``[0, 1]``, or a dimension is not positive.
        """
        # Keep in sync with the Literal annotation on ``fusion_type``.
        allowed_fusion = ("concat", "weighted_sum", "attention")
        if self.fusion_type not in allowed_fusion:
            raise ValueError(
                f"ModelConfig.fusion_type must be one of {allowed_fusion}, got {self.fusion_type!r}"
            )

        for name in ("fusion_dropout", "modality_dropout_p"):
            p = getattr(self, name)
            if not 0.0 <= p <= 1.0:
                raise ValueError(f"ModelConfig.{name} must be in [0, 1], got {p!r}")

        for name in ("vit_embed_dim", "ast_embed_dim", "fusion_dim", "num_classes"):
            value = getattr(self, name)
            if value <= 0:
                raise ValueError(f"ModelConfig.{name} must be positive, got {value!r}")
|
|
|
|
@dataclass
class LoRAConfig:
    """LoRA adapter settings for the ViT and AST backbones.

    The fields mirror ``peft.LoraConfig`` options and are presumably forwarded
    to it (TODO confirm at the call site). Each list field is built by its own
    ``default_factory``, so every instance owns independent lists and mutating
    one config never affects another.
    """

    enabled: bool = True  # the full-FT and linear-probe ablations switch this off
    r: int = 8  # adapter rank
    lora_alpha: int = 16  # scaling numerator (PEFT scales adapter output by alpha / r)
    lora_dropout: float = 0.1  # dropout on the adapter path
    bias: str = "none"  # PEFT ``bias`` option

    # Visual (ViT) branch: modules receiving adapters / kept fully trainable.
    # NOTE(review): "classifier" only matches if the wrapped backbone really has a
    # submodule of that name (e.g. a *ForImageClassification head) — confirm.
    vit_target_modules: List[str] = field(default_factory=lambda: ["query", "value"])
    vit_modules_to_save: List[str] = field(default_factory=lambda: ["classifier"])

    # Audio (AST) branch: same roles as the ViT lists above.
    ast_target_modules: List[str] = field(default_factory=lambda: ["query", "value"])
    ast_modules_to_save: List[str] = field(default_factory=lambda: ["classifier"])
|
|
|
|
@dataclass
class TrainConfig:
    """Training-loop, checkpointing and logging settings.

    Most field names match Hugging Face ``TrainingArguments`` and are presumably
    forwarded to it (TODO confirm in the training script). ``mode``,
    ``finetune_method`` and ``lora_learning_rate`` are project-specific.
    ``__post_init__`` enforces the ``Literal`` choices for ``mode`` and
    ``finetune_method``, because Python does not check ``Literal`` at runtime.
    """

    # --- What to train ---
    mode: Literal["multimodal", "visual_only", "audio_only"] = "multimodal"
    finetune_method: Literal["lora", "full", "linear_probe"] = "lora"

    # --- Optimisation ---
    learning_rate: float = 5e-4
    lora_learning_rate: float = 5e-3  # presumably used for adapter params; TODO confirm
    weight_decay: float = 0.01
    warmup_ratio: float = 0.1
    max_grad_norm: float = 1.0
    num_epochs: int = 15
    per_device_train_batch_size: int = 16
    per_device_eval_batch_size: int = 32
    gradient_accumulation_steps: int = 2  # effective train batch per device: 16 * 2 = 32
    fp16: bool = True

    # --- Evaluation and model selection ---
    eval_strategy: str = "epoch"
    metric_for_best_model: str = "macro_f1"

    # --- Checkpointing and Hub publishing ---
    output_dir: str = "./results"
    push_to_hub: bool = True
    hub_model_id: str = "Ellaft/multimodal-pc-fault-detector"
    # HF requires save_strategy to match eval_strategy when load_best_model_at_end is set.
    save_strategy: str = "epoch"
    save_total_limit: int = 3
    load_best_model_at_end: bool = True

    # --- Logging ---
    logging_steps: int = 10
    logging_strategy: str = "steps"
    logging_first_step: bool = True
    disable_tqdm: bool = True
    seed: int = 42

    def __post_init__(self) -> None:
        """Validate the project-specific enumerated fields.

        Raises:
            ValueError: if ``mode`` or ``finetune_method`` is not an allowed value.
        """
        # Keep these tuples in sync with the Literal annotations above.
        valid_modes = ("multimodal", "visual_only", "audio_only")
        valid_methods = ("lora", "full", "linear_probe")
        if self.mode not in valid_modes:
            raise ValueError(f"TrainConfig.mode must be one of {valid_modes}, got {self.mode!r}")
        if self.finetune_method not in valid_methods:
            raise ValueError(
                f"TrainConfig.finetune_method must be one of {valid_methods}, "
                f"got {self.finetune_method!r}"
            )
|
|
|
|
@dataclass
class ExperimentConfig:
    """Top-level bundle of every setting for one experiment run.

    Each sub-config is built by its own ``default_factory``, so every
    ``ExperimentConfig()`` owns independent ``DataConfig``, ``ModelConfig``,
    ``LoRAConfig`` and ``TrainConfig`` instances. That makes it safe to mutate
    one experiment's sub-configs (as ``get_ablation_configs`` does) without
    touching another's.
    """

    # --- Component configs (fresh instance per experiment) ---
    data: DataConfig = field(default_factory=DataConfig)
    model: ModelConfig = field(default_factory=ModelConfig)
    lora: LoRAConfig = field(default_factory=LoRAConfig)
    train: TrainConfig = field(default_factory=TrainConfig)

    # --- Run identification ---
    experiment_name: str = "multimodal_pc_fault"
    description: str = "Two-branch audio-visual fusion for PC fault detection with OGM-GE"

    # --- OGM-GE (Peng et al., CVPR 2022) ---
    ogm_alpha: float = 0.3  # gradient-modulation strength
    ogm_noise_sigma: float = 0.1  # presumably std of the generalisation-enhancement noise

    # --- Auxiliary loss weights (presumably per-modality unimodal heads) ---
    lambda_visual: float = 1.5
    lambda_audio: float = 0.5

    # --- Asymmetric learning-rate multipliers (visual boosted, audio damped) ---
    visual_lr_multiplier: float = 3.0
    audio_lr_multiplier: float = 0.5

    # --- Data source on the Hugging Face Hub ---
    hub_dataset: str = "Ellaft/pc-fault-real-dataset"
|
|
|
|
def _build_ablation_config(
    name,
    mode,
    *,
    finetune_method="lora",
    learning_rate=None,
    modality_dropout_p=None,
    hub_model_id=None,
):
    """Return a fresh ``ExperimentConfig`` customised for one ablation run.

    Args:
        name: Experiment name. It becomes ``experiment_name``, the key in
            ``get_ablation_configs``, and this run's sub-directory under the
            default ``TrainConfig.output_dir``.
        mode: Value for ``TrainConfig.mode``.
        finetune_method: Value for ``TrainConfig.finetune_method``. LoRA
            adapters are enabled exactly when this is ``"lora"``, so the two
            settings cannot drift apart.
        learning_rate: Overrides ``TrainConfig.learning_rate`` when not None.
        modality_dropout_p: Overrides ``ModelConfig.modality_dropout_p`` when not None.
        hub_model_id: Overrides ``TrainConfig.hub_model_id`` when not None;
            otherwise the ``TrainConfig`` default is kept.
    """
    cfg = ExperimentConfig()
    cfg.experiment_name = name
    cfg.train.mode = mode
    cfg.train.finetune_method = finetune_method
    cfg.lora.enabled = finetune_method == "lora"
    # One output directory per run: with a shared directory, later runs can
    # overwrite or pick up earlier runs' checkpoints and saved model files.
    cfg.train.output_dir = f"{cfg.train.output_dir.rstrip('/')}/{name}"
    if learning_rate is not None:
        cfg.train.learning_rate = learning_rate
    if modality_dropout_p is not None:
        cfg.model.modality_dropout_p = modality_dropout_p
    if hub_model_id is not None:
        cfg.train.hub_model_id = hub_model_id
    return cfg


def get_ablation_configs():
    """Return the ablation-study experiment configurations keyed by name.

    Every entry is an independent ``ExperimentConfig`` whose outputs go to
    ``<default output_dir>/<name>``. The insertion order is: multimodal_lora,
    visual_only_lora, audio_only_lora, multimodal_full_ft,
    multimodal_linear_probe, multimodal_robust.

    NOTE(review): the unimodal runs keep the default ``modality_dropout_p``;
    confirm the model disables modality dropout when only one branch is active.

    Returns:
        dict mapping experiment name (str) to its ``ExperimentConfig``.
    """
    runs = [
        # Reference LoRA run; keeps the default TrainConfig.hub_model_id.
        _build_ablation_config("multimodal_lora", "multimodal"),
        _build_ablation_config(
            "visual_only_lora",
            "visual_only",
            hub_model_id="Ellaft/pc-fault-visual-only",
        ),
        _build_ablation_config(
            "audio_only_lora",
            "audio_only",
            hub_model_id="Ellaft/pc-fault-audio-only",
        ),
        _build_ablation_config(
            "multimodal_full_ft",
            "multimodal",
            finetune_method="full",
            learning_rate=2e-5,
            hub_model_id="Ellaft/pc-fault-multimodal-full-ft",
        ),
        _build_ablation_config(
            "multimodal_linear_probe",
            "multimodal",
            finetune_method="linear_probe",
            learning_rate=1e-3,
            hub_model_id="Ellaft/pc-fault-multimodal-linear-probe",
        ),
        _build_ablation_config(
            "multimodal_robust",
            "multimodal",
            modality_dropout_p=0.5,
            hub_model_id="Ellaft/pc-fault-multimodal-robust",
        ),
    ]
    return {cfg.experiment_name: cfg for cfg in runs}
|
|