| """ |
| Multimodal PC Fault Detection - Flexible Dataset Module |
| ======================================================== |
| Supports multiple data source modes: |
| Audio: "esc50" (default) | "custom" (any HF audio dataset) |
| Visual: "synthetic" (default) | "custom" (any HF image dataset) |
| |
| See examples at bottom of file or in README. |
| """ |
|
|
import random

import numpy as np
import torch
from torch.utils.data import Dataset
from PIL import Image, ImageDraw
import torchaudio.transforms as T
from datasets import load_dataset

from config import (FAULT_CLASSES, ESC50_TO_FAULT, VISUAL_SYNTHESIS, DataConfig, ModelConfig)

class SyntheticVisualGenerator:
    """Renders deterministic placeholder diagnostic images for each fault class."""

    def __init__(self, image_size=224):
        self.image_size, self.templates = image_size, VISUAL_SYNTHESIS

    def generate(self, fault_class_name, variation_seed=0):
        rng = random.Random(variation_seed)
        t = self.templates[fault_class_name]
        # Jitter the template's dominant color so repeated samples are not identical.
        br, bg, bb = t["color_dominant"]
        r, g, b = [max(0, min(255, c + rng.randint(-30, 30))) for c in (br, bg, bb)]
        img = Image.new("RGB", (self.image_size, self.image_size), (r, g, b))
        draw = ImageDraw.Draw(img)
        # Draw the template's text overlays in black or white, whichever contrasts.
        text_color = (255, 255, 255) if (r + g + b) / 3 < 128 else (0, 0, 0)
        y = rng.randint(10, 40)
        for text in t["text_overlay"]:
            draw.text((rng.randint(5, max(10, self.image_size // 4)), y), text, fill=text_color)
            y += rng.randint(25, 45)
        # Add seeded pixel noise so repeated generations still vary at the pixel level.
        arr = np.array(img)
        noise = np.random.RandomState(variation_seed).randint(-15, 16, arr.shape, dtype=np.int16)
        return Image.fromarray(np.clip(arr.astype(np.int16) + noise, 0, 255).astype(np.uint8))
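
# Example (a minimal sketch; "fan_failure" is a hypothetical class name -- substitute
# any key that exists in FAULT_CLASSES / VISUAL_SYNTHESIS in your config.py):
#   gen = SyntheticVisualGenerator(image_size=224)
#   img = gen.generate("fan_failure", variation_seed=3)  # deterministic per seed
#   img.save("fan_failure_preview.png")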


class AudioPreprocessor:
    """Converts raw waveforms to fixed-length log-mel spectrograms."""

    def __init__(self, config):
        self.config, self.target_length = config, int(config.sample_rate * config.audio_duration)
        self.mel = T.MelSpectrogram(sample_rate=config.sample_rate, n_fft=config.n_fft,
            hop_length=config.hop_length, n_mels=config.n_mels, f_min=config.fmin, f_max=config.fmax)
        self.to_db = T.AmplitudeToDB(top_db=80)

    def process(self, waveform, sr):
        # Downmix to mono and ensure a (channels, samples) layout.
        if waveform.dim() > 1 and waveform.shape[0] > 1:
            waveform = waveform.mean(0, keepdim=True)
        elif waveform.dim() == 1:
            waveform = waveform.unsqueeze(0)
        if sr != self.config.sample_rate:
            waveform = T.Resample(sr, self.config.sample_rate)(waveform)
        # Pad or truncate to a fixed duration so every sample has the same shape.
        if waveform.shape[-1] < self.target_length:
            waveform = torch.nn.functional.pad(waveform, (0, self.target_length - waveform.shape[-1]))
        else:
            waveform = waveform[:, :self.target_length]
        return self.to_db(self.mel(waveform))

    def augment(self, x, training=True):
        # SpecAugment-style time/frequency masking, applied only at training time.
        if not training:
            return x
        return T.TimeMasking(self.config.time_mask_max)(T.FrequencyMasking(self.config.freq_mask_max)(x))
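
# Example (a minimal sketch; assumes DataConfig from config.py carries sample_rate,
# audio_duration, n_fft, hop_length, n_mels, fmin, fmax, freq_mask_max, time_mask_max):
#   pre = AudioPreprocessor(DataConfig())
#   wave = torch.randn(44100)                # one second of fake audio at 44.1 kHz
#   spec = pre.process(wave, sr=44100)       # -> (1, n_mels, time_frames) log-mel
#   spec = pre.augment(spec, training=True)  # masking is skipped when training=False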


class PCFaultDataset(Dataset):
    """
    Flexible multimodal dataset. Works with any HF audio + image dataset.

    AUDIO: set audio_source="custom" and provide:
        custom_audio_dataset, custom_audio_column, custom_audio_label_column, custom_audio_label_map

    VISUAL: set visual_source="custom" and provide:
        custom_visual_dataset, custom_visual_column, custom_visual_label_column, custom_visual_label_map
    """
    def __init__(self, config, model_config, split="train", vit_processor=None, ast_feature_extractor=None, augment=True,
                 audio_source="esc50", custom_audio_dataset=None, custom_audio_config=None, custom_audio_split=None,
                 custom_audio_column="audio", custom_audio_label_column="label", custom_audio_label_map=None,
                 visual_source="synthetic", custom_visual_dataset=None, custom_visual_config=None, custom_visual_split=None,
                 custom_visual_column="image", custom_visual_label_column="label", custom_visual_label_map=None):
        self.config, self.split = config, split
        self.augment = augment and (split == "train")
        self.vit_processor, self.ast_feature_extractor = vit_processor, ast_feature_extractor
        self.audio_source, self.visual_source = audio_source, visual_source
        self.audio_preprocessor = AudioPreprocessor(config)
        self.visual_gen = SyntheticVisualGenerator(config.image_size)
        self._audio_column = custom_audio_column

        # --- Audio samples ----------------------------------------------------
        if audio_source == "esc50":
            # ESC-50 ships as a single split; train/val are separated by fold.
            self.audio_data = load_dataset(config.esc50_dataset, split="train")
            self.samples = []
            for idx, row in enumerate(self.audio_data):
                cat = row["category"]
                if cat in ESC50_TO_FAULT:
                    fl, fold = ESC50_TO_FAULT[cat], row["fold"]
                    if (split == "train" and fold != config.val_fold) or (split == "val" and fold == config.val_fold):
                        self.samples.append({"audio_idx": idx, "fault_label": fl})
        elif audio_source == "custom":
            assert custom_audio_dataset, "custom_audio_dataset required"
            assert custom_audio_label_map, "custom_audio_label_map required: {original_label: 0-4}"
            kw = {"name": custom_audio_config} if custom_audio_config else {}
            self.audio_data = load_dataset(custom_audio_dataset, split=custom_audio_split or split, **kw)
            self.samples = []
            for idx, row in enumerate(self.audio_data):
                lbl = row[custom_audio_label_column]
                # Accept both raw and stringified labels as label-map keys.
                mapped = custom_audio_label_map.get(lbl, custom_audio_label_map.get(str(lbl)))
                if mapped is not None:
                    self.samples.append({"audio_idx": idx, "fault_label": mapped})
        else:
            raise ValueError(f"Unknown audio_source: {audio_source!r} (expected 'esc50' or 'custom')")

        # --- Visual samples (indexed per class; falls back to synthetic) -------
        self.visual_data, self._visual_by_class = None, {i: [] for i in range(len(FAULT_CLASSES))}
        if visual_source == "custom" and custom_visual_dataset:
            kw = {"name": custom_visual_config} if custom_visual_config else {}
            self.visual_data = load_dataset(custom_visual_dataset, split=custom_visual_split or split, **kw)
            self._visual_column = custom_visual_column
            for idx, row in enumerate(self.visual_data):
                lbl = row.get(custom_visual_label_column)
                mapped = (custom_visual_label_map or {}).get(lbl, (custom_visual_label_map or {}).get(str(lbl)))
                if mapped is not None:
                    self._visual_by_class[mapped].append(idx)
            print(f"[Visual] Custom: {custom_visual_dataset}, per-class: { {k: len(v) for k, v in self._visual_by_class.items()} }")

        # --- Summary ------------------------------------------------------------
        label_counts = {}
        for s in self.samples:
            label_counts[s["fault_label"]] = label_counts.get(s["fault_label"], 0) + 1
        print(f"[Dataset] {split} | audio={audio_source} | visual={visual_source} | total={len(self.samples)}")
        for l, c in sorted(label_counts.items()):
            print(f"  {FAULT_CLASSES[l]}: {c}")

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        s = self.samples[idx]
        fl, fn = s["fault_label"], FAULT_CLASSES[s["fault_label"]]

        # --- Audio ---------------------------------------------------------
        row = self.audio_data[s["audio_idx"]]
        ad = row["audio"] if self.audio_source == "esc50" else row[self._audio_column]
        if isinstance(ad, dict) and "array" in ad:
            arr, sr = np.array(ad["array"], dtype=np.float32), ad["sampling_rate"]
        else:
            arr, sr = np.array(ad, dtype=np.float32), self.config.sample_rate

        if self.ast_feature_extractor:
            # AST expects 16 kHz input; resample if needed, then extract features.
            if sr != 16000:
                arr = T.Resample(sr, 16000)(torch.tensor(arr).unsqueeze(0)).squeeze(0).numpy()
            audio_values = self.ast_feature_extractor(arr, sampling_rate=16000, return_tensors="pt")["input_values"].squeeze(0)
        else:
            lm = self.audio_preprocessor.process(torch.tensor(arr, dtype=torch.float32), sr)
            if self.augment:
                lm = self.audio_preprocessor.augment(lm)
            audio_values = lm.squeeze(0)

        # --- Visual ----------------------------------------------------------
        if self.visual_data is not None and self._visual_by_class.get(fl):
            # Cycle deterministically through the real images for this class.
            ci = self._visual_by_class[fl]
            img = self.visual_data[ci[idx % len(ci)]][self._visual_column]
            if not isinstance(img, Image.Image):
                img = Image.fromarray(np.array(img))
        else:
            img = self.visual_gen.generate(fn, variation_seed=idx * 7 + fl)
        if img.mode != "RGB":
            img = img.convert("RGB")

        if self.vit_processor:
            pv = self.vit_processor(images=img, return_tensors="pt")["pixel_values"].squeeze(0)
        else:
            # Fallback: manual ImageNet normalization, kept in float32 for the model.
            a = np.array(img.resize((224, 224))).astype(np.float32) / 255.0
            mean = np.array([0.485, 0.456, 0.406], dtype=np.float32)
            std = np.array([0.229, 0.224, 0.225], dtype=np.float32)
            pv = torch.from_numpy((a - mean) / std).permute(2, 0, 1)

        return {"pixel_values": pv, "audio_values": audio_values, "labels": torch.tensor(fl, dtype=torch.long)}


def multimodal_collate_fn(batch):
    """Stacks images and labels; right-pads audio features to the batch maximum."""
    pv = torch.stack([b["pixel_values"] for b in batch])
    labels = torch.stack([b["labels"] for b in batch])
    al = [b["audio_values"] for b in batch]
    ml = max(a.shape[-1] for a in al)
    return {"pixel_values": pv, "labels": labels,
            "audio_values": torch.stack([torch.nn.functional.pad(a, (0, ml - a.shape[-1])) if a.shape[-1] < ml else a for a in al])}