| """ |
| Multimodal PC Fault Detection - Dataset Module |
| =============================================== |
| Handles audio/visual data loading, synthetic visual generation, |
| ESC-50 → fault mapping, and paired multimodal batching. |
| """ |
|
|
| import os |
| import random |
| import numpy as np |
| import torch |
| from torch.utils.data import Dataset |
| from PIL import Image, ImageDraw, ImageFont |
| from typing import Dict, Optional, Tuple |
| import torchaudio |
| import torchaudio.transforms as T |
| from datasets import load_dataset |
|
|
| from config import ( |
| FAULT_CLASSES, ESC50_TO_FAULT, ESC50_CATEGORY_TO_TARGET, |
| VISUAL_SYNTHESIS, DataConfig, ModelConfig |
| ) |
|
|
|
|
class SyntheticVisualGenerator:
    """Generates synthetic PC diagnostic screen images for each fault class.

    An image consists of a solid background (the template's dominant colour
    with a small random jitter), the template's text lines, one indicator
    shape specific to the fault class, and uniform per-pixel noise. The
    output is fully determined by ``fault_class_name``, ``variation_seed``
    and ``image_size``.
    """

    # Canvas size the indicator coordinates were originally laid out for.
    _REFERENCE_SIZE = 224

    def __init__(self, image_size: int = 224):
        """Store the square canvas size and the per-class visual templates."""
        self.image_size = image_size
        self.templates = VISUAL_SYNTHESIS

    def _px(self, value):
        """Scale a coordinate designed for a 224px canvas to ``image_size``.

        On a 224px canvas this returns ``value`` unchanged, so images at the
        default size are laid out exactly as before.
        """
        return int(round(value * self.image_size / self._REFERENCE_SIZE))

    def generate(self, fault_class_name: str, variation_seed: int = 0) -> Image.Image:
        """Render one synthetic diagnostic image.

        Args:
            fault_class_name: Key into the templates mapping; the template
                must provide ``"color_dominant"`` (an RGB triple) and
                ``"text_overlay"`` (an iterable of strings).
            variation_seed: Seed for the colour/layout jitter and the noise.

        Returns:
            An RGB ``PIL.Image.Image`` of ``image_size`` x ``image_size``.

        Raises:
            KeyError: If ``fault_class_name`` has no template.
        """
        rng = random.Random(variation_seed)
        template = self.templates[fault_class_name]
        size = self.image_size
        px = self._px
        white = (255, 255, 255)

        # Jitter the dominant colour, clamped to the valid 8-bit range.
        base_r, base_g, base_b = template["color_dominant"]
        r = max(0, min(255, base_r + rng.randint(-30, 30)))
        g = max(0, min(255, base_g + rng.randint(-30, 30)))
        b = max(0, min(255, base_b + rng.randint(-30, 30)))
        img = Image.new("RGB", (size, size), (r, g, b))
        draw = ImageDraw.Draw(img)

        # The background is fixed, so the contrasting text colour is chosen
        # once rather than on every loop iteration.
        text_color = white if (r + g + b) / 3 < 128 else (0, 0, 0)
        # Text offsets stay in absolute pixels because draw.text is called
        # without a font argument, so the glyph size does not follow the canvas.
        y_offset = rng.randint(10, 40)
        for text in template["text_overlay"]:
            x = rng.randint(5, max(10, size // 4))
            draw.text((x, y_offset), text, fill=text_color)
            y_offset += rng.randint(25, 45)

        # Class-specific indicator. Coordinates are scaled from the 224px
        # reference layout so the shape stays on-canvas for other sizes.
        if fault_class_name == "normal_operation":
            bar_y = rng.randint(px(100), px(180))
            bar_right = int(size * rng.uniform(0.5, 0.95))
            draw.rectangle([px(10), bar_y, bar_right, bar_y + px(15)], fill=(0, 200, 0))
        elif fault_class_name == "system_crash":
            cx, cy = size // 2, size // 3
            draw.text((cx - 10, cy), ":(", fill=white)
        elif fault_class_name == "overheating_fan":
            gauge_x = rng.randint(px(40), px(120))
            gauge_top, gauge_bottom = px(150), px(200)
            draw.rectangle(
                [gauge_x, gauge_top, gauge_x + px(30), gauge_bottom], outline=white
            )
            # The gauge is filled to 70-100% of its height.
            fill_h = int((gauge_bottom - gauge_top) * rng.uniform(0.7, 1.0))
            draw.rectangle(
                [gauge_x + px(2), gauge_bottom - fill_h, gauge_x + px(28), px(198)],
                fill=(255, 50, 50),
            )
        elif fault_class_name == "storage_failure":
            bar_y = px(160)
            draw.rectangle([px(20), bar_y, px(200), bar_y + px(12)], outline=white)
            progress = rng.uniform(0.1, 0.4)
            inner_left = px(22)
            inner_right = int(inner_left + px(176) * progress)
            draw.rectangle(
                [inner_left, bar_y + px(2), inner_right, bar_y + px(10)],
                fill=(255, 165, 0),
            )
        elif fault_class_name == "boot_failure":
            cx, cy = rng.randint(px(10), px(60)), rng.randint(px(140), px(200))
            draw.rectangle([cx, cy, cx + px(10), cy + px(15)], fill=white)

        # Add uniform noise in [-15, 15]; int16 avoids uint8 wrap-around
        # before clipping. RandomState only accepts seeds in [0, 2**32 - 1],
        # so wrap the seed into that range instead of raising on others.
        img_array = np.array(img)
        noise_rng = np.random.RandomState(variation_seed % (2 ** 32))
        noise = noise_rng.randint(-15, 16, img_array.shape, dtype=np.int16)
        img_array = np.clip(img_array.astype(np.int16) + noise, 0, 255).astype(np.uint8)
        return Image.fromarray(img_array)
|
|
|
|
class AudioPreprocessor:
    """Converts raw waveforms into fixed-length mel spectrograms in decibels.

    All spectrogram parameters (sample rate, FFT size, hop length, mel bins,
    frequency range) and the clip duration are read from ``config``.
    """

    def __init__(self, config: DataConfig):
        """Build the mel and decibel transforms from ``config``."""
        self.config = config
        # Number of samples every clip is padded or trimmed to.
        self.target_length = int(config.sample_rate * config.audio_duration)
        mel_settings = {
            "sample_rate": config.sample_rate,
            "n_fft": config.n_fft,
            "hop_length": config.hop_length,
            "n_mels": config.n_mels,
            "f_min": config.fmin,
            "f_max": config.fmax,
        }
        self.mel_transform = T.MelSpectrogram(**mel_settings)
        self.amplitude_to_db = T.AmplitudeToDB(top_db=80)

    def process(self, waveform, sample_rate):
        """Return the decibel-scaled mel spectrogram of ``waveform``.

        A 1-D waveform gets a leading channel axis; a multi-channel waveform
        is averaged down to one channel. The signal is then resampled to the
        configured rate if needed and zero-padded or truncated to
        ``target_length`` samples before the transforms are applied.
        """
        if waveform.dim() == 1:
            waveform = waveform.unsqueeze(0)
        elif waveform.dim() > 1 and waveform.shape[0] > 1:
            # Down-mix by averaging over the first (channel) axis.
            waveform = waveform.mean(dim=0, keepdim=True)

        wanted_rate = self.config.sample_rate
        if sample_rate != wanted_rate:
            resampler = T.Resample(orig_freq=sample_rate, new_freq=wanted_rate)
            waveform = resampler(waveform)

        missing = self.target_length - waveform.shape[-1]
        if missing > 0:
            # Too short: append zeros on the right.
            waveform = torch.nn.functional.pad(waveform, (0, missing))
        else:
            # Long enough: keep only the leading target_length samples.
            waveform = waveform[:, :self.target_length]

        mel = self.mel_transform(waveform)
        return self.amplitude_to_db(mel)

    def augment(self, log_mel, training=True):
        """Apply frequency masking then time masking when ``training`` is truthy.

        With ``training`` falsy the input is returned untouched.
        """
        if training:
            freq_masker = T.FrequencyMasking(freq_mask_param=self.config.freq_mask_max)
            log_mel = freq_masker(log_mel)
            time_masker = T.TimeMasking(time_mask_param=self.config.time_mask_max)
            return time_masker(log_mel)
        return log_mel
|
|
|
|
class PCFaultDataset(Dataset):
    """Paired audio/visual dataset built from ESC-50 clips mapped to fault classes.

    Each item pairs one ESC-50 audio clip (whose category appears in
    ``ESC50_TO_FAULT``) with a synthetic diagnostic image generated for the
    same fault class. Clips whose ``fold`` equals ``config.val_fold`` form the
    ``"val"`` split; all other mapped clips form the ``"train"`` split. Any
    other ``split`` value yields an empty dataset.
    """

    # Sampling rate the AST feature-extractor branch resamples to.
    _AST_SAMPLE_RATE = 16000
    # Per-channel mean / std used by the fallback image normalisation
    # (the commonly used ImageNet statistics).
    _PIXEL_MEAN = (0.485, 0.456, 0.406)
    _PIXEL_STD = (0.229, 0.224, 0.225)

    def __init__(self, config, model_config, split="train", vit_processor=None, ast_feature_extractor=None, augment=True):
        """Load ESC-50, index the clips belonging to ``split`` and print class counts.

        Args:
            config: Data configuration; this class reads ``esc50_dataset``,
                ``val_fold`` and ``image_size`` and passes the object on to
                ``AudioPreprocessor``.
            model_config: Stored on the instance; not otherwise used here.
            split: ``"train"`` or ``"val"``.
            vit_processor: Optional image processor; when ``None`` a built-in
                resize + normalise fallback is used.
            ast_feature_extractor: Optional audio feature extractor; when
                ``None`` the local ``AudioPreprocessor`` is used.
            augment: Enables spectrogram masking, but only on the train split.
        """
        self.config, self.model_config, self.split = config, model_config, split
        self.augment = augment and (split == "train")
        self.vit_processor, self.ast_feature_extractor = vit_processor, ast_feature_extractor
        self.esc50 = load_dataset(config.esc50_dataset, split="train")

        # Read only the two metadata columns: iterating whole rows would also
        # materialise the audio of every clip just to inspect its labels.
        categories = self.esc50["category"]
        folds = self.esc50["fold"]
        self.samples = []
        for idx, (category, fold) in enumerate(zip(categories, folds)):
            if category not in ESC50_TO_FAULT:
                continue
            is_val_fold = fold == config.val_fold
            if (split == "train" and not is_val_fold) or (split == "val" and is_val_fold):
                self.samples.append(
                    {"audio_idx": idx, "fault_label": ESC50_TO_FAULT[category], "category": category}
                )

        self.visual_gen = SyntheticVisualGenerator(config.image_size)
        self.audio_preprocessor = AudioPreprocessor(config)

        label_counts = {}
        for s in self.samples:
            label_counts[s["fault_label"]] = label_counts.get(s["fault_label"], 0) + 1
        # "\n" must be a real newline escape; "\\n" printed a literal backslash-n.
        print(f"\n[PCFaultDataset] Split: {split}, Total: {len(self.samples)}")
        for label, count in sorted(label_counts.items()):
            print(f" Class {label} ({FAULT_CLASSES[label]}): {count}")

    def __len__(self):
        """Number of indexed samples in this split."""
        return len(self.samples)

    def __getitem__(self, idx):
        """Return ``{"pixel_values", "audio_values", "labels"}`` for sample ``idx``."""
        sample = self.samples[idx]
        fault_label = sample["fault_label"]
        fault_name = FAULT_CLASSES[fault_label]

        audio_data = self.esc50[sample["audio_idx"]]["audio"]
        audio_values = self._audio_features(audio_data)

        # The seed is derived from the index and label, so a given sample
        # always gets the same image.
        visual_image = self.visual_gen.generate(fault_name, variation_seed=idx * 7 + fault_label)
        pixel_values = self._image_features(visual_image)

        return {
            "pixel_values": pixel_values,
            "audio_values": audio_values,
            "labels": torch.tensor(fault_label, dtype=torch.long),
        }

    def _audio_features(self, audio_data):
        """Turn one decoded audio record (``"array"``, ``"sampling_rate"``) into model input."""
        if self.ast_feature_extractor is not None:
            audio_array = np.array(audio_data["array"], dtype=np.float32)
            sr = audio_data["sampling_rate"]
            target_sr = self._AST_SAMPLE_RATE
            if sr != target_sr:
                wf = T.Resample(orig_freq=sr, new_freq=target_sr)(torch.tensor(audio_array).unsqueeze(0))
                audio_array = wf.squeeze(0).numpy()
            features = self.ast_feature_extractor(audio_array, sampling_rate=target_sr, return_tensors="pt")
            return features["input_values"].squeeze(0)

        waveform = torch.tensor(audio_data["array"], dtype=torch.float32)
        log_mel = self.audio_preprocessor.process(waveform, audio_data["sampling_rate"])
        if self.augment:
            log_mel = self.audio_preprocessor.augment(log_mel, training=True)
        # Drop the leading channel axis added by the preprocessor.
        return log_mel.squeeze(0)

    def _image_features(self, visual_image):
        """Turn a PIL image into a channel-first tensor for the vision model."""
        if self.vit_processor is not None:
            return self.vit_processor(images=visual_image, return_tensors="pt")["pixel_values"].squeeze(0)
        # NOTE(review): the fallback always resizes to 224x224 regardless of
        # config.image_size - confirm this matches the vision backbone.
        return self._normalize_pixels(np.array(visual_image.resize((224, 224))))

    @staticmethod
    def _normalize_pixels(img_array):
        """Scale an HxWx3 0-255 array to [0, 1], normalise per channel, return CxHxW float32.

        The mean/std are float32 arrays on purpose: subtracting or dividing by
        plain Python lists promotes the result to float64, which produced a
        double-precision ``pixel_values`` tensor.
        """
        mean = np.array(PCFaultDataset._PIXEL_MEAN, dtype=np.float32)
        std = np.array(PCFaultDataset._PIXEL_STD, dtype=np.float32)
        scaled = np.asarray(img_array).astype(np.float32) / 255.0
        normalized = (scaled - mean) / std
        return torch.tensor(normalized).permute(2, 0, 1)
|
|
|
|
def multimodal_collate_fn(batch):
    """Collate per-sample dicts into one batched dict.

    ``pixel_values`` and ``labels`` are stacked directly. ``audio_values``
    tensors are first zero-padded on the right of their last dimension up to
    the longest one in the batch, then stacked.

    Returns:
        A dict with keys ``"pixel_values"``, ``"audio_values"`` and ``"labels"``.
    """
    stacked_images = torch.stack([item["pixel_values"] for item in batch])
    stacked_labels = torch.stack([item["labels"] for item in batch])

    clips = [item["audio_values"] for item in batch]
    longest = max(clip.shape[-1] for clip in clips)

    aligned = []
    for clip in clips:
        deficit = longest - clip.shape[-1]
        if deficit > 0:
            # Shorter than the longest clip: append zeros along the last axis.
            clip = torch.nn.functional.pad(clip, (0, deficit))
        aligned.append(clip)

    return {
        "pixel_values": stacked_images,
        "audio_values": torch.stack(aligned),
        "labels": stacked_labels,
    }
|
|