"""
Multimodal PC Fault Detection - Configuration
==============================================
Hyperparameters, dataset configs, experiment settings, and anti-modality-collapse params.
Includes OGM-GE (Peng et al., CVPR 2022) hyperparameters, auxiliary loss weights,
and asymmetric learning rate multipliers.
"""
from dataclasses import dataclass, field
from typing import List, Optional, Literal
# ============================================================================
# Fault Taxonomy
# ============================================================================
# Label taxonomy: position in this list IS the integer class id.
FAULT_CLASSES = [
    "normal_operation",  # 0 — healthy baseline
    "boot_failure",      # 1 — BIOS beep codes, POST failures
    "overheating_fan",   # 2 — fan noise anomalies, thermal warnings
    "storage_failure",   # 3 — HDD clicking, drive errors
    "system_crash",      # 4 — BSOD, system halts
]
NUM_CLASSES = len(FAULT_CLASSES)
@dataclass
class DataConfig:
    """Audio/image preprocessing and augmentation hyperparameters.

    Audio settings parameterize a log-mel spectrogram front end; the image
    size matches the ViT backbone's expected input resolution.
    """
    # -- audio front end --
    sample_rate: int = 32000        # Hz
    audio_duration: float = 5.0     # seconds per clip
    n_fft: int = 1024               # STFT window length (samples)
    hop_length: int = 320           # 32000 / 320 -> 100 spectrogram frames per second
    n_mels: int = 64                # mel filterbank bins
    fmin: int = 50                  # Hz; cuts sub-audible rumble
    fmax: int = 14000               # Hz
    # -- image front end --
    image_size: int = 224           # matches vit-base-patch16-224 input resolution
    # -- augmentation strengths (presumably train-time only — confirm in the dataloader) --
    audio_noise_snr_db: float = 10.0  # SNR of injected noise, dB
    time_shift_max: float = 0.2       # max time shift; assumed fraction of clip length — TODO confirm units
    freq_mask_max: int = 10           # SpecAugment-style frequency mask width (mel bins) — verify against augment code
    time_mask_max: int = 20           # SpecAugment-style time mask width (frames) — verify against augment code
@dataclass
class ModelConfig:
    """Backbone checkpoints, fusion head, and regularization settings."""
    # Visual branch: ViT-Base/16 pretrained on ImageNet-21k.
    vit_model_name: str = "google/vit-base-patch16-224-in21k"
    vit_embed_dim: int = 768
    # Audio branch: Audio Spectrogram Transformer fine-tuned on AudioSet.
    ast_model_name: str = "MIT/ast-finetuned-audioset-10-10-0.4593"
    ast_embed_dim: int = 768
    # How the two branch embeddings are combined before classification.
    fusion_type: Literal["concat", "weighted_sum", "attention"] = "concat"
    fusion_dim: int = 512
    fusion_dropout: float = 0.3
    num_classes: int = NUM_CLASSES
    # Probability of dropping an entire modality during training
    # (anti-collapse regularizer) — presumably applied in the model's forward; confirm.
    modality_dropout_p: float = 0.3
@dataclass
class LoRAConfig:
    """PEFT LoRA adapter settings, shared shape for both backbones."""
    enabled: bool = True
    r: int = 8                  # adapter rank
    lora_alpha: int = 16        # PEFT scales adapter output by lora_alpha / r
    lora_dropout: float = 0.1
    bias: str = "none"          # bias params stay frozen
    # Attention projections to wrap with LoRA, per backbone; the classifier
    # head is trained and saved alongside the adapters.
    vit_target_modules: List[str] = field(default_factory=lambda: ["query", "value"])
    vit_modules_to_save: List[str] = field(default_factory=lambda: ["classifier"])
    ast_target_modules: List[str] = field(default_factory=lambda: ["query", "value"])
    ast_modules_to_save: List[str] = field(default_factory=lambda: ["classifier"])
@dataclass
class TrainConfig:
    """Optimization and Hugging Face Trainer settings for one run."""
    # Which modality branch(es) participate in this run.
    mode: Literal["multimodal", "visual_only", "audio_only"] = "multimodal"
    finetune_method: Literal["lora", "full", "linear_probe"] = "lora"
    learning_rate: float = 5e-4         # base LR (full FT / linear probe)
    lora_learning_rate: float = 5e-3    # LR for LoRA adapters — presumably selected by finetune_method; confirm in trainer setup
    weight_decay: float = 0.01
    warmup_ratio: float = 0.1
    max_grad_norm: float = 1.0
    num_epochs: int = 15
    per_device_train_batch_size: int = 16
    per_device_eval_batch_size: int = 32
    gradient_accumulation_steps: int = 2   # effective per-device train batch = 16 * 2
    fp16: bool = True
    eval_strategy: str = "epoch"
    metric_for_best_model: str = "macro_f1"
    output_dir: str = "./results"
    push_to_hub: bool = True
    hub_model_id: str = "Ellaft/multimodal-pc-fault-detector"
    # NOTE(review): Trainer requires save_strategy == eval_strategy when
    # load_best_model_at_end is True — satisfied here ("epoch" == "epoch").
    save_strategy: str = "epoch"
    save_total_limit: int = 3
    load_best_model_at_end: bool = True
    logging_steps: int = 10
    logging_strategy: str = "steps"
    logging_first_step: bool = True
    disable_tqdm: bool = True
    seed: int = 42
@dataclass
class ExperimentConfig:
    """Top-level bundle: nested sub-configs plus experiment-level knobs."""
    data: DataConfig = field(default_factory=DataConfig)
    model: ModelConfig = field(default_factory=ModelConfig)
    lora: LoRAConfig = field(default_factory=LoRAConfig)
    train: TrainConfig = field(default_factory=TrainConfig)
    experiment_name: str = "multimodal_pc_fault"
    description: str = "Two-branch audio-visual fusion for PC fault detection with OGM-GE"
    # Anti-modality-collapse hyperparameters
    # OGM-GE: Peng et al., CVPR 2022 (arXiv: 2203.15332)
    ogm_alpha: float = 0.3          # gradient-modulation strength
    ogm_noise_sigma: float = 0.1    # std of generalization-enhancement noise
    lambda_visual: float = 1.5  # Boost weak visual branch
    lambda_audio: float = 0.5  # Dampen dominant audio branch
    # Asymmetric LR multipliers — presumably applied to per-branch optimizer
    # param groups; confirm in the optimizer setup.
    visual_lr_multiplier: float = 3.0
    audio_lr_multiplier: float = 0.5
    # Dataset
    hub_dataset: str = "Ellaft/pc-fault-real-dataset"
def get_ablation_configs():
    """Build the named experiment configurations for the ablation study.

    Returns:
        dict[str, ExperimentConfig]: six configurations keyed by experiment
        name — multimodal / visual-only / audio-only LoRA runs, a full
        fine-tuning run, a linear-probe run, and a robustness run with
        heavier modality dropout. Every config starts from the
        ``ExperimentConfig`` defaults and overrides only the listed fields.
    """
    def _make(name, *, mode="multimodal", finetune_method=None,
              lora_enabled=None, learning_rate=None, hub_model_id=None,
              modality_dropout_p=None):
        # One fresh config per variant; None means "keep the default".
        cfg = ExperimentConfig()
        cfg.experiment_name = name
        cfg.train.mode = mode
        if finetune_method is not None:
            cfg.train.finetune_method = finetune_method
        if lora_enabled is not None:
            cfg.lora.enabled = lora_enabled
        if learning_rate is not None:
            cfg.train.learning_rate = learning_rate
        if hub_model_id is not None:
            cfg.train.hub_model_id = hub_model_id
        if modality_dropout_p is not None:
            cfg.model.modality_dropout_p = modality_dropout_p
        return cfg

    return {
        # Main model: both branches, LoRA fine-tuning, default hub repo.
        "multimodal_lora": _make("multimodal_lora"),
        # Unimodal baselines.
        "visual_only_lora": _make(
            "visual_only_lora", mode="visual_only",
            hub_model_id="Ellaft/pc-fault-visual-only"),
        "audio_only_lora": _make(
            "audio_only_lora", mode="audio_only",
            hub_model_id="Ellaft/pc-fault-audio-only"),
        # Full fine-tuning needs a much lower LR than LoRA.
        "multimodal_full_ft": _make(
            "multimodal_full_ft", finetune_method="full", lora_enabled=False,
            learning_rate=2e-5,
            hub_model_id="Ellaft/pc-fault-multimodal-full-ft"),
        # Frozen backbones, train the classifier head only.
        "multimodal_linear_probe": _make(
            "multimodal_linear_probe", finetune_method="linear_probe",
            lora_enabled=False, learning_rate=1e-3,
            hub_model_id="Ellaft/pc-fault-multimodal-linear-probe"),
        # Robustness variant: heavier modality dropout (0.5 vs default 0.3).
        "multimodal_robust": _make(
            "multimodal_robust", modality_dropout_p=0.5,
            hub_model_id="Ellaft/pc-fault-multimodal-robust"),
    }