File size: 5,543 Bytes
50fb488
 
 
d1b8db1
 
 
 
50fb488
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d1b8db1
 
 
 
 
 
 
 
 
 
 
 
 
 
50fb488
 
 
 
d1b8db1
50fb488
 
 
 
d1b8db1
50fb488
 
 
 
 
d1b8db1
50fb488
 
 
 
 
d1b8db1
50fb488
 
 
 
 
 
 
 
d1b8db1
50fb488
 
 
 
 
 
 
 
d1b8db1
50fb488
 
 
 
 
 
d1b8db1
50fb488
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
"""
Multimodal PC Fault Detection - Configuration
==============================================
Hyperparameters, dataset configs, experiment settings, and anti-modality-collapse params.

Includes OGM-GE (Peng et al., CVPR 2022) hyperparameters, auxiliary loss weights,
and asymmetric learning rate multipliers.
"""

from dataclasses import dataclass, field
from typing import List, Optional, Literal


# ============================================================================
# Fault Taxonomy
# ============================================================================
# Class names for the classifier output. The "# N:" prefix on each entry
# marks its integer label, which is its position in this list, so reordering
# or inserting entries changes the label mapping of any trained model/dataset.
FAULT_CLASSES = [
    "normal_operation",     # 0: Normal PC running
    "boot_failure",         # 1: BIOS beep codes, POST failures
    "overheating_fan",      # 2: Fan noise anomalies, thermal warnings
    "storage_failure",      # 3: HDD clicking, drive errors
    "system_crash",         # 4: BSOD, system halts
]
# Derived from FAULT_CLASSES so the count cannot disagree with the list.
# Note that ModelConfig.num_classes copies this value once, when that class
# is defined.
NUM_CLASSES = len(FAULT_CLASSES)


@dataclass
class DataConfig:
    """Audio/image preprocessing and augmentation settings.

    Units not spelled out in a field's name are noted as assumptions below;
    confirm them against the data pipeline that consumes this config.
    """

    # --- Audio loading / mel spectrogram ---
    sample_rate: int = 32000        # target sample rate (presumably Hz)
    audio_duration: float = 5.0     # clip length (presumably seconds)
    n_fft: int = 1024               # FFT window size, in samples
    hop_length: int = 320           # STFT hop in samples (10 ms if sample_rate is 32 kHz)
    n_mels: int = 64                # number of mel bands
    fmin: int = 50                  # lowest mel-filter frequency (presumably Hz)
    # Highest mel-filter frequency; must not exceed the Nyquist frequency
    # (sample_rate / 2 = 16000 with the default sample_rate).
    fmax: int = 14000
    # NOTE(review): ModelConfig.ast_model_name is an AST AudioSet checkpoint
    # whose Hugging Face feature extractor defaults to 16 kHz audio and 128
    # mel bins. Confirm whether the values above feed AST directly or only a
    # separate spectrogram/augmentation path.

    # --- Visual input ---
    image_size: int = 224           # image-branch input resolution in pixels (presumably square)

    # --- Augmentation ---
    audio_noise_snr_db: float = 10.0  # SNR in dB, presumably for additive-noise augmentation
    time_shift_max: float = 0.2       # max random time shift; unit not shown (fraction of clip vs seconds) — confirm
    freq_mask_max: int = 10           # max frequency-mask width (presumably mel bins)
    time_mask_max: int = 20           # max time-mask width (presumably spectrogram frames)


@dataclass
class ModelConfig:
    """Backbone checkpoints and fusion-head settings for the two-branch model."""

    # Visual branch: ViT checkpoint id and its embedding width. Keep
    # vit_embed_dim equal to the checkpoint's hidden size (768 for ViT-Base).
    vit_model_name: str = "google/vit-base-patch16-224-in21k"
    vit_embed_dim: int = 768
    # Audio branch: Audio Spectrogram Transformer (AST) checkpoint id and its
    # embedding width (768 for this base-sized checkpoint).
    ast_model_name: str = "MIT/ast-finetuned-audioset-10-10-0.4593"
    ast_embed_dim: int = 768
    # How the two branch embeddings are combined. Presumably the fused
    # representation is projected to fusion_dim with fusion_dropout applied
    # before the classifier — confirm in the model code.
    fusion_type: Literal["concat", "weighted_sum", "attention"] = "concat"
    fusion_dim: int = 512
    fusion_dropout: float = 0.3
    # Default is taken from NUM_CLASSES once, when this class is defined.
    num_classes: int = NUM_CLASSES
    # Presumably the probability of dropping an entire modality during
    # training, to make the model robust to a missing input — confirm.
    # The "multimodal_robust" ablation raises it to 0.5.
    modality_dropout_p: float = 0.3


@dataclass
class LoRAConfig:
    """LoRA adapter settings, with separate module lists per backbone.

    Field names mirror ``peft.LoraConfig`` arguments (``r``, ``lora_alpha``,
    ``lora_dropout``, ``bias``, ``target_modules``, ``modules_to_save``).
    Presumably they are forwarded there — confirm in the model-building code.
    """

    enabled: bool = True            # set to False by the full-FT and linear-probe ablations
    r: int = 8                      # adapter rank
    lora_alpha: int = 16            # PEFT scales the adapter update by lora_alpha / r (= 2 here)
    lora_dropout: float = 0.1
    bias: str = "none"              # PEFT accepts "none", "all" or "lora_only"
    # Submodules to wrap with adapters ("query"/"value" are presumably the
    # self-attention projections in the HF ViT/AST implementations) and
    # submodules to train fully and save alongside the adapters (presumably
    # each backbone's classification head). default_factory gives every
    # instance its own list, so mutating one config's list never affects
    # another config.
    vit_target_modules: List[str] = field(default_factory=lambda: ["query", "value"])
    vit_modules_to_save: List[str] = field(default_factory=lambda: ["classifier"])
    ast_target_modules: List[str] = field(default_factory=lambda: ["query", "value"])
    ast_modules_to_save: List[str] = field(default_factory=lambda: ["classifier"])


@dataclass
class TrainConfig:
    """Training-loop settings.

    Most field names match ``transformers.TrainingArguments``. ``mode``,
    ``finetune_method`` and ``lora_learning_rate`` are project-specific.
    The comments below assume the matching fields are forwarded to
    ``TrainingArguments`` — confirm in the training script.
    """

    # Which branch(es) are trained: both, or a single-modality baseline.
    mode: Literal["multimodal", "visual_only", "audio_only"] = "multimodal"
    finetune_method: Literal["lora", "full", "linear_probe"] = "lora"
    learning_rate: float = 5e-4
    # Separate, 10x higher LR; presumably applied to the LoRA adapter
    # parameters only — confirm.
    lora_learning_rate: float = 5e-3
    weight_decay: float = 0.01
    warmup_ratio: float = 0.1       # fraction of total training steps used for LR warmup
    max_grad_norm: float = 1.0      # gradient-norm clipping threshold
    num_epochs: int = 15
    # With accumulation, each optimizer step sees 16 * 2 = 32 samples per device.
    per_device_train_batch_size: int = 16
    per_device_eval_batch_size: int = 32
    gradient_accumulation_steps: int = 2
    fp16: bool = True               # mixed-precision training
    eval_strategy: str = "epoch"
    # Must name a key returned by the metrics function; the HF Trainer adds
    # the "eval_" prefix when it is missing.
    metric_for_best_model: str = "macro_f1"
    output_dir: str = "./results"
    # On by default, so a run uploads to hub_model_id unless overridden.
    push_to_hub: bool = True
    hub_model_id: str = "Ellaft/multimodal-pc-fault-detector"
    # Keep save_strategy equal to eval_strategy. The HF Trainer requires
    # matching strategies when load_best_model_at_end is True.
    save_strategy: str = "epoch"
    save_total_limit: int = 3       # max checkpoints kept in output_dir
    load_best_model_at_end: bool = True
    logging_steps: int = 10
    logging_strategy: str = "steps"
    logging_first_step: bool = True
    disable_tqdm: bool = True
    seed: int = 42


@dataclass
class ExperimentConfig:
    """Top-level experiment config.

    Bundles the data, model, LoRA and training sections with experiment
    metadata, the anti-modality-collapse hyperparameters and the Hub dataset id.
    """

    # default_factory builds fresh section objects per instance, so one
    # experiment's sections can be mutated in place without affecting another.
    data: DataConfig = field(default_factory=DataConfig)
    model: ModelConfig = field(default_factory=ModelConfig)
    lora: LoRAConfig = field(default_factory=LoRAConfig)
    train: TrainConfig = field(default_factory=TrainConfig)
    experiment_name: str = "multimodal_pc_fault"
    description: str = "Two-branch audio-visual fusion for PC fault detection with OGM-GE"

    # Anti-modality-collapse hyperparameters
    # OGM-GE: Peng et al., CVPR 2022 (arXiv: 2203.15332)
    # ogm_alpha is presumably the paper's gradient-modulation strength alpha,
    # and ogm_noise_sigma the scale of the Gaussian noise used for
    # generalization enhancement (GE) — confirm in the training code.
    ogm_alpha: float = 0.3
    ogm_noise_sigma: float = 0.1
    # Presumably per-branch auxiliary-loss weights.
    lambda_visual: float = 1.5      # Boost weak visual branch
    lambda_audio: float = 0.5       # Dampen dominant audio branch
    # Per-branch multipliers on the base learning rate (asymmetric LRs):
    # > 1 speeds up the visual branch, < 1 slows the audio branch.
    # Presumably applied via separate optimizer param groups — confirm.
    visual_lr_multiplier: float = 3.0
    audio_lr_multiplier: float = 0.5

    # Dataset
    # Hugging Face Hub dataset repository id.
    hub_dataset: str = "Ellaft/pc-fault-real-dataset"


def get_ablation_configs() -> "dict[str, ExperimentConfig]":
    """Build the named experiment configurations for the ablation study.

    Every entry starts from a fresh ``ExperimentConfig()`` and overrides only
    the fields that distinguish that run. No ``data``/``model``/``lora``/
    ``train`` section object is shared between entries.

    Returns:
        A dict mapping each config's ``experiment_name`` to the config, in
        this order: ``multimodal_lora``, ``visual_only_lora``,
        ``audio_only_lora``, ``multimodal_full_ft``,
        ``multimodal_linear_probe``, ``multimodal_robust``.
    """
    configs = {}

    def register(cfg):
        # Key by the config's own name so the dict key and
        # cfg.experiment_name cannot drift apart when entries are edited.
        configs[cfg.experiment_name] = cfg

    # Main model: both branches with LoRA; keeps the default hub_model_id.
    cfg = ExperimentConfig()
    cfg.train.mode = "multimodal"
    cfg.experiment_name = "multimodal_lora"
    register(cfg)

    # Single-branch baselines (LoRA), each pushed to its own Hub repo.
    cfg = ExperimentConfig()
    cfg.train.mode = "visual_only"
    cfg.experiment_name = "visual_only_lora"
    cfg.train.hub_model_id = "Ellaft/pc-fault-visual-only"
    register(cfg)

    cfg = ExperimentConfig()
    cfg.train.mode = "audio_only"
    cfg.experiment_name = "audio_only_lora"
    cfg.train.hub_model_id = "Ellaft/pc-fault-audio-only"
    register(cfg)

    # Full fine-tuning: LoRA disabled and a much smaller base LR
    # (2e-5 vs. the 5e-4 default).
    cfg = ExperimentConfig()
    cfg.train.mode = "multimodal"
    cfg.train.finetune_method = "full"
    cfg.lora.enabled = False
    cfg.train.learning_rate = 2e-5
    cfg.experiment_name = "multimodal_full_ft"
    cfg.train.hub_model_id = "Ellaft/pc-fault-multimodal-full-ft"
    register(cfg)

    # Linear probe: LoRA disabled and a larger base LR (1e-3).
    cfg = ExperimentConfig()
    cfg.train.mode = "multimodal"
    cfg.train.finetune_method = "linear_probe"
    cfg.lora.enabled = False
    cfg.train.learning_rate = 1e-3
    cfg.experiment_name = "multimodal_linear_probe"
    cfg.train.hub_model_id = "Ellaft/pc-fault-multimodal-linear-probe"
    register(cfg)

    # Robustness variant: modality dropout raised from 0.3 to 0.5.
    cfg = ExperimentConfig()
    cfg.train.mode = "multimodal"
    cfg.model.modality_dropout_p = 0.5
    cfg.experiment_name = "multimodal_robust"
    cfg.train.hub_model_id = "Ellaft/pc-fault-multimodal-robust"
    register(cfg)

    return configs