# Provenance: Hugging Face commit d1b8db1 (verified), uploaded by Ellaft —
# "Overwrite config.py with v2 config (OGM-GE params, asymmetric LR,
# auxiliary loss weights)"
"""
Multimodal PC Fault Detection - Configuration
==============================================
Hyperparameters, dataset configs, experiment settings, and anti-modality-collapse params.
Includes OGM-GE (Peng et al., CVPR 2022) hyperparameters, auxiliary loss weights,
and asymmetric learning rate multipliers.
"""
from dataclasses import dataclass, field
from typing import List, Optional, Literal
# ============================================================================
# Fault Taxonomy
# ============================================================================
# Index in this list == integer class label used by the model head.
FAULT_CLASSES = [
    "normal_operation",  # 0: healthy baseline (normal PC running)
    "boot_failure",      # 1: BIOS beep codes, POST failures
    "overheating_fan",   # 2: fan-noise anomalies, thermal warnings
    "storage_failure",   # 3: HDD clicking, drive errors
    "system_crash",      # 4: BSOD, system halts
]

# Derived so the head size always tracks the taxonomy above.
NUM_CLASSES = len(FAULT_CLASSES)
@dataclass
class DataConfig:
    """Audio/image preprocessing and augmentation hyperparameters."""

    # --- Audio front-end (mel spectrogram) ---
    sample_rate: int = 32000     # Hz
    audio_duration: float = 5.0  # seconds per clip — presumably clips are padded/cropped to this; confirm in the dataset code
    n_fft: int = 1024            # STFT window size
    hop_length: int = 320        # STFT hop; 32000 / 320 = 100 frames per second
    n_mels: int = 64             # mel filterbank bins
    fmin: int = 50               # Hz, lower mel cutoff
    fmax: int = 14000            # Hz, upper mel cutoff

    # --- Visual front-end ---
    image_size: int = 224        # square input resolution, matches the ViT patch-16/224 backbone

    # --- Augmentation ---
    audio_noise_snr_db: float = 10.0  # SNR (dB) for injected noise
    time_shift_max: float = 0.2       # max time shift — presumably a fraction of clip length; verify against augment code
    freq_mask_max: int = 10           # SpecAugment-style frequency-mask width (mel bins)
    time_mask_max: int = 20           # SpecAugment-style time-mask width (frames)
@dataclass
class ModelConfig:
    """Backbone, fusion, and classifier-head settings for the two-branch model."""

    # Visual branch: ViT-Base/16 pretrained on ImageNet-21k (hidden size 768).
    vit_model_name: str = "google/vit-base-patch16-224-in21k"
    vit_embed_dim: int = 768
    # Audio branch: Audio Spectrogram Transformer fine-tuned on AudioSet.
    ast_model_name: str = "MIT/ast-finetuned-audioset-10-10-0.4593"
    ast_embed_dim: int = 768
    # How branch embeddings are combined before the classifier head.
    fusion_type: Literal["concat", "weighted_sum", "attention"] = "concat"
    fusion_dim: int = 512          # projection size of the fused representation
    fusion_dropout: float = 0.3
    num_classes: int = NUM_CLASSES # tracks the FAULT_CLASSES taxonomy
    # Probability of dropping one whole modality during training
    # (anti-collapse regularizer) — granularity (per sample vs per batch)
    # is decided in the model code; confirm there.
    modality_dropout_p: float = 0.3
@dataclass
class LoRAConfig:
    """LoRA adapter settings applied to both the ViT and AST backbones."""

    enabled: bool = True
    r: int = 8                # low-rank adapter dimension
    lora_alpha: int = 16      # LoRA scaling numerator (effective scale = alpha / r)
    lora_dropout: float = 0.1
    bias: str = "none"        # do not adapt bias terms
    # Attention projections to wrap with LoRA, and extra modules kept fully
    # trainable (saved alongside the adapters), per branch.
    vit_target_modules: List[str] = field(default_factory=lambda: ["query", "value"])
    vit_modules_to_save: List[str] = field(default_factory=lambda: ["classifier"])
    ast_target_modules: List[str] = field(default_factory=lambda: ["query", "value"])
    ast_modules_to_save: List[str] = field(default_factory=lambda: ["classifier"])
@dataclass
class TrainConfig:
    """Optimization, evaluation, and Hub-publishing settings (Trainer-style)."""

    # Which branch(es) to train — used by the ablation configs below.
    mode: Literal["multimodal", "visual_only", "audio_only"] = "multimodal"
    finetune_method: Literal["lora", "full", "linear_probe"] = "lora"

    # --- Optimization ---
    learning_rate: float = 5e-4       # base LR (non-LoRA params)
    lora_learning_rate: float = 5e-3  # higher LR for the low-rank adapters
    weight_decay: float = 0.01
    warmup_ratio: float = 0.1         # fraction of total steps used for LR warmup
    max_grad_norm: float = 1.0        # gradient clipping
    num_epochs: int = 15
    per_device_train_batch_size: int = 16
    per_device_eval_batch_size: int = 32
    gradient_accumulation_steps: int = 2  # effective train batch = 16 * 2 per device
    fp16: bool = True

    # --- Evaluation / checkpointing ---
    eval_strategy: str = "epoch"
    metric_for_best_model: str = "macro_f1"
    output_dir: str = "./results"
    push_to_hub: bool = True
    hub_model_id: str = "Ellaft/multimodal-pc-fault-detector"
    save_strategy: str = "epoch"
    save_total_limit: int = 3
    load_best_model_at_end: bool = True

    # --- Logging ---
    logging_steps: int = 10
    logging_strategy: str = "steps"
    logging_first_step: bool = True
    disable_tqdm: bool = True  # cleaner logs in non-interactive runs

    seed: int = 42
@dataclass
class ExperimentConfig:
    """Top-level experiment bundle: sub-configs plus anti-collapse knobs."""

    data: DataConfig = field(default_factory=DataConfig)
    model: ModelConfig = field(default_factory=ModelConfig)
    lora: LoRAConfig = field(default_factory=LoRAConfig)
    train: TrainConfig = field(default_factory=TrainConfig)
    experiment_name: str = "multimodal_pc_fault"
    description: str = "Two-branch audio-visual fusion for PC fault detection with OGM-GE"

    # --- Anti-modality-collapse hyperparameters ---
    # OGM-GE gradient modulation: Peng et al., CVPR 2022 (arXiv:2203.15332).
    ogm_alpha: float = 0.3        # modulation strength
    ogm_noise_sigma: float = 0.1  # std of the Gaussian gradient noise (the "GE" term)
    # Auxiliary per-branch loss weights.
    lambda_visual: float = 1.5  # Boost weak visual branch
    lambda_audio: float = 0.5   # Dampen dominant audio branch
    # Asymmetric LR multipliers applied on top of the base learning rate.
    visual_lr_multiplier: float = 3.0
    audio_lr_multiplier: float = 0.5

    # Dataset
    hub_dataset: str = "Ellaft/pc-fault-real-dataset"
def get_ablation_configs() -> dict:
    """Build the named experiment variants for the ablation study.

    Variants: the full multimodal LoRA model, each single-modality baseline,
    full fine-tuning, linear probing, and a robustness run with heavier
    modality dropout.

    Returns:
        dict mapping experiment name -> ExperimentConfig. The default
        multimodal variant keeps the TrainConfig default hub_model_id;
        every other variant pushes to its own Hub repo.
    """
    configs: dict = {}

    def _add(name, mode, hub_model_id=None):
        # Create one variant, register it under `name`, and return it so the
        # caller can apply further per-variant overrides.
        cfg = ExperimentConfig()
        cfg.train.mode = mode
        cfg.experiment_name = name
        if hub_model_id is not None:
            cfg.train.hub_model_id = hub_model_id
        configs[name] = cfg
        return cfg

    # Main model and single-modality baselines (LoRA fine-tuning throughout).
    _add("multimodal_lora", "multimodal")
    _add("visual_only_lora", "visual_only", "Ellaft/pc-fault-visual-only")
    _add("audio_only_lora", "audio_only", "Ellaft/pc-fault-audio-only")

    # Full fine-tuning: LoRA off, much lower LR to avoid wrecking pretrained weights.
    cfg = _add("multimodal_full_ft", "multimodal", "Ellaft/pc-fault-multimodal-full-ft")
    cfg.train.finetune_method = "full"
    cfg.lora.enabled = False
    cfg.train.learning_rate = 2e-5

    # Linear probe: frozen backbones, only the head trains, so a higher LR is safe.
    cfg = _add("multimodal_linear_probe", "multimodal", "Ellaft/pc-fault-multimodal-linear-probe")
    cfg.train.finetune_method = "linear_probe"
    cfg.lora.enabled = False
    cfg.train.learning_rate = 1e-3

    # Robustness run: heavier modality dropout to stress-test missing-modality behavior.
    cfg = _add("multimodal_robust", "multimodal", "Ellaft/pc-fault-multimodal-robust")
    cfg.model.modality_dropout_p = 0.5

    return configs