class SyntheticVisualGenerator:
    """Render deterministic synthetic images for a fault class.

    Each image is a flat background whose colour is the template's dominant
    colour with a small random jitter, with the template's overlay text lines
    drawn on top and low-amplitude per-pixel noise added. The same
    ``(fault_class_name, variation_seed)`` pair always yields the same image.
    """

    # numpy's legacy RandomState only accepts integer seeds in [0, 2**32).
    _NUMPY_SEED_MODULUS = 2 ** 32

    def __init__(self, image_size=224, templates=None):
        """Store the output size and the per-class templates.

        Args:
            image_size: Side length, in pixels, of the square output image.
            templates: Optional mapping from fault-class name to a template
                dict holding ``"color_dominant"`` (three channel values) and
                ``"text_overlay"`` (an iterable of strings). Defaults to the
                module-level ``VISUAL_SYNTHESIS`` table.
        """
        self.image_size = image_size
        self.templates = VISUAL_SYNTHESIS if templates is None else templates

    def generate(self, fault_class_name, variation_seed=0):
        """Return one synthetic RGB ``PIL.Image`` for ``fault_class_name``.

        Args:
            fault_class_name: Key into ``self.templates``.
            variation_seed: Integer seed selecting the variation. Any integer
                is accepted; it is reduced modulo 2**32 for the numpy noise
                generator, which rejects seeds outside that range.

        Raises:
            KeyError: If ``fault_class_name`` has no template.
        """
        rng = random.Random(variation_seed)
        template = self.templates[fault_class_name]
        size = self.image_size

        # Jitter each background channel by up to +/-30 and clamp to 0..255.
        # Exactly three channels are expected, as in the original unpacking.
        red, green, blue = (
            max(0, min(255, channel + rng.randint(-30, 30)))
            for channel in template["color_dominant"]
        )
        image = Image.new("RGB", (size, size), (red, green, blue))
        draw = ImageDraw.Draw(image)

        # White text on dark backgrounds, black on light ones. This depends
        # only on the background, so it is computed once rather than per line.
        is_dark = (red + green + blue) / 3 < 128
        text_color = (255, 255, 255) if is_dark else (0, 0, 0)

        max_x = max(10, size // 4)
        y = rng.randint(10, 40)
        for text in template["text_overlay"]:
            # Draw order of RNG calls (x first, then the y step) is kept so
            # existing seeds keep producing the same layout.
            x = rng.randint(5, max_x)
            draw.text((x, y), text, fill=text_color)
            y += rng.randint(25, 45)

        # Add integer noise in [-15, 15]; work in int16 so the sum cannot wrap
        # before it is clipped back into the uint8 range.
        pixels = np.array(image).astype(np.int16)
        noise_rng = np.random.RandomState(variation_seed % self._NUMPY_SEED_MODULUS)
        noise = noise_rng.randint(-15, 16, pixels.shape, dtype=np.int16)
        return Image.fromarray(np.clip(pixels + noise, 0, 255).astype(np.uint8))
class AudioPreprocessor:
    """Turn raw waveforms into fixed-length dB-scaled mel spectrograms.

    All settings are read from ``config``: ``sample_rate``,
    ``audio_duration``, ``n_fft``, ``hop_length``, ``n_mels``, ``fmin``,
    ``fmax``, and, for augmentation, ``time_mask_max`` and ``freq_mask_max``.
    """

    def __init__(self, config):
        """Build the mel and dB transforms described by ``config``."""
        self.config = config
        # Number of samples every clip is padded or truncated to.
        self.target_length = int(config.sample_rate * config.audio_duration)
        self.mel = T.MelSpectrogram(
            sample_rate=config.sample_rate,
            n_fft=config.n_fft,
            hop_length=config.hop_length,
            n_mels=config.n_mels,
            f_min=config.fmin,
            f_max=config.fmax,
        )
        self.to_db = T.AmplitudeToDB(top_db=80)
        # (source_rate, target_rate) -> T.Resample, filled lazily by
        # _get_resampler so each transform is built only once.
        self._resamplers = {}

    def _get_resampler(self, source_rate):
        """Return a cached ``T.Resample`` from ``source_rate`` to the config rate."""
        key = (source_rate, self.config.sample_rate)
        resampler = self._resamplers.get(key)
        if resampler is None:
            resampler = T.Resample(source_rate, self.config.sample_rate)
            self._resamplers[key] = resampler
        return resampler

    def process(self, waveform, sr):
        """Convert one waveform tensor to a dB-scaled mel spectrogram.

        Args:
            waveform: 1-D tensor of samples, or a 2-D tensor whose first
                dimension is averaged away when it is larger than one.
            sr: Sampling rate of ``waveform``.

        Returns:
            ``self.to_db(self.mel(x))`` where ``x`` is the single-row
            waveform at ``config.sample_rate``, exactly ``target_length``
            samples long.
        """
        if waveform.dim() > 1 and waveform.shape[0] > 1:
            # Collapse multiple rows into one by averaging.
            waveform = waveform.mean(0, keepdim=True)
        elif waveform.dim() == 1:
            waveform = waveform.unsqueeze(0)

        if sr != self.config.sample_rate:
            waveform = self._get_resampler(sr)(waveform)

        length = waveform.shape[-1]
        if length < self.target_length:
            # Right-pad short clips with zeros.
            waveform = torch.nn.functional.pad(waveform, (0, self.target_length - length))
        else:
            waveform = waveform[:, :self.target_length]
        return self.to_db(self.mel(waveform))

    def augment(self, x, training=True):
        """Apply frequency masking then time masking; a no-op when not training."""
        if not training:
            return x
        freq_masked = T.FrequencyMasking(self.config.freq_mask_max)(x)
        return T.TimeMasking(self.config.time_mask_max)(freq_masked)
class PCFaultDataset(Dataset):
    """Flexible multimodal (audio + image) fault-classification dataset.

    Every item pairs one audio clip with one image of the same fault class.

    AUDIO
        ``audio_source="esc50"`` loads ``config.esc50_dataset`` and keeps the
        rows whose ``category`` appears in ``ESC50_TO_FAULT``. Rows whose
        ``fold`` equals ``config.val_fold`` form the ``"val"`` split and all
        other rows the ``"train"`` split; any other ``split`` value selects
        no rows.
        ``audio_source="custom"`` loads ``custom_audio_dataset`` and requires
        ``custom_audio_label_map`` (original label -> fault label). Related
        options: ``custom_audio_config``, ``custom_audio_split``,
        ``custom_audio_column``, ``custom_audio_label_column``.

    VISUAL
        ``visual_source="custom"`` together with ``custom_visual_dataset``
        loads real images, grouped per fault class through
        ``custom_visual_label_map``. Related options:
        ``custom_visual_config``, ``custom_visual_split``,
        ``custom_visual_column``, ``custom_visual_label_column``.
        In every other case, and for classes without any custom image,
        images come from ``SyntheticVisualGenerator``.

    Items are dicts with ``"pixel_values"``, ``"audio_values"`` and
    ``"labels"``.
    """

    # Sampling rate the AST feature-extractor path resamples to.
    _AST_SAMPLE_RATE = 16000
    # Side length used when no ViT processor is supplied.
    _FALLBACK_IMAGE_SIZE = 224
    # Per-channel statistics for the fallback normalisation
    # (presumably the usual ImageNet values -- confirm against the model).
    _FALLBACK_MEAN = (0.485, 0.456, 0.406)
    _FALLBACK_STD = (0.229, 0.224, 0.225)

    def __init__(self, config, model_config, split="train", vit_processor=None,
                 ast_feature_extractor=None, augment=True,
                 audio_source="esc50", custom_audio_dataset=None,
                 custom_audio_config=None, custom_audio_split=None,
                 custom_audio_column="audio", custom_audio_label_column="label",
                 custom_audio_label_map=None,
                 visual_source="synthetic", custom_visual_dataset=None,
                 custom_visual_config=None, custom_visual_split=None,
                 custom_visual_column="image", custom_visual_label_column="label",
                 custom_visual_label_map=None):
        """Load the audio (and optional image) datasets and index them by label.

        Raises:
            ValueError: If ``audio_source`` is neither ``"esc50"`` nor
                ``"custom"``.
            AssertionError: If ``audio_source="custom"`` is used without
                ``custom_audio_dataset`` or ``custom_audio_label_map``.
        """
        # Fail early: any other value would leave self.samples undefined and
        # surface later as an unrelated AttributeError.
        if audio_source not in ("esc50", "custom"):
            raise ValueError(
                f"Unsupported audio_source {audio_source!r}; expected 'esc50' or 'custom'"
            )

        self.config, self.split = config, split
        # Augmentation is only ever applied to the training split.
        self.augment = augment and (split == "train")
        self.vit_processor = vit_processor
        self.ast_feature_extractor = ast_feature_extractor
        self.audio_source, self.visual_source = audio_source, visual_source
        self.audio_preprocessor = AudioPreprocessor(config)
        self.visual_gen = SyntheticVisualGenerator(config.image_size)
        self._audio_column = custom_audio_column
        # (source_rate, target_rate) -> T.Resample, filled lazily.
        self._resamplers = {}

        # Load audio
        if audio_source == "esc50":
            self.audio_data = load_dataset(config.esc50_dataset, split="train")
            self.samples = self._index_esc50(self.audio_data, split, config.val_fold)
        else:
            assert custom_audio_dataset, "custom_audio_dataset required"
            assert custom_audio_label_map, "custom_audio_label_map required: {original_label: 0-4}"
            kw = {"name": custom_audio_config} if custom_audio_config else {}
            self.audio_data = load_dataset(
                custom_audio_dataset, split=custom_audio_split or split, **kw)
            # Read only the label column so the audio column is not fetched
            # for every row while indexing.
            self.samples = []
            for idx, lbl in enumerate(self.audio_data[custom_audio_label_column]):
                mapped = self._map_label(custom_audio_label_map, lbl)
                if mapped is not None:
                    self.samples.append({"audio_idx": idx, "fault_label": mapped})

        # Load visual
        self.visual_data = None
        self._visual_by_class = {i: [] for i in range(len(FAULT_CLASSES))}
        if visual_source == "custom" and custom_visual_dataset:
            kw = {"name": custom_visual_config} if custom_visual_config else {}
            self.visual_data = load_dataset(
                custom_visual_dataset, split=custom_visual_split or split, **kw)
            self._visual_column = custom_visual_column
            label_map = custom_visual_label_map or {}
            if custom_visual_label_column in self.visual_data.column_names:
                # Column access avoids fetching the image of every row.
                labels = self.visual_data[custom_visual_label_column]
            else:
                # A missing label column behaves like a label of None per row.
                labels = [None] * len(self.visual_data)
            for idx, lbl in enumerate(labels):
                mapped = self._map_label(label_map, lbl)
                if mapped is not None:
                    # setdefault: a label outside range(len(FAULT_CLASSES))
                    # gets its own bucket instead of raising KeyError.
                    self._visual_by_class.setdefault(mapped, []).append(idx)
            print(f"[Visual] Custom: {custom_visual_dataset}, per-class: { {k:len(v) for k,v in self._visual_by_class.items()} }")

        # Stats
        label_counts = {}
        for sample in self.samples:
            label = sample["fault_label"]
            label_counts[label] = label_counts.get(label, 0) + 1
        print(f"[Dataset] {split} | audio={audio_source} | visual={visual_source} | total={len(self.samples)}")
        for label, count in sorted(label_counts.items()):
            print(f" {FAULT_CLASSES[label]}: {count}")

    @staticmethod
    def _map_label(label_map, label):
        """Look ``label`` up in ``label_map``, falling back to ``str(label)``."""
        return label_map.get(label, label_map.get(str(label)))

    @staticmethod
    def _index_esc50(audio_data, split, val_fold):
        """Return the sample records of ``audio_data`` that belong to ``split``."""
        samples = []
        # Only the two metadata columns are read; the audio column is left
        # untouched until __getitem__.
        rows = zip(audio_data["category"], audio_data["fold"])
        for idx, (category, fold) in enumerate(rows):
            if category not in ESC50_TO_FAULT:
                continue
            in_val_fold = fold == val_fold
            if (split == "train" and not in_val_fold) or (split == "val" and in_val_fold):
                samples.append({"audio_idx": idx, "fault_label": ESC50_TO_FAULT[category]})
        return samples

    def _resample(self, waveform, source_rate, target_rate):
        """Resample ``waveform`` using a cached ``T.Resample`` transform."""
        key = (source_rate, target_rate)
        resampler = self._resamplers.get(key)
        if resampler is None:
            resampler = self._resamplers[key] = T.Resample(source_rate, target_rate)
        return resampler(waveform)

    def _load_audio(self, audio_idx):
        """Return the model-ready audio tensor for row ``audio_idx``."""
        row = self.audio_data[audio_idx]
        ad = row["audio"] if self.audio_source == "esc50" else row[self._audio_column]
        if isinstance(ad, dict) and "array" in ad:
            arr, sr = np.array(ad["array"], dtype=np.float32), ad["sampling_rate"]
        else:
            # A bare array carries no rate; it is taken to be at the config rate.
            arr, sr = np.array(ad, dtype=np.float32), self.config.sample_rate

        if self.ast_feature_extractor:
            target = self._AST_SAMPLE_RATE
            if sr != target:
                arr = self._resample(torch.tensor(arr).unsqueeze(0), sr, target).squeeze(0).numpy()
            features = self.ast_feature_extractor(arr, sampling_rate=target, return_tensors="pt")
            return features["input_values"].squeeze(0)

        log_mel = self.audio_preprocessor.process(torch.tensor(arr, dtype=torch.float32), sr)
        if self.augment:
            log_mel = self.audio_preprocessor.augment(log_mel)
        return log_mel.squeeze(0)

    def _load_image(self, idx, fault_label, fault_name):
        """Return an RGB ``PIL.Image`` of class ``fault_label`` for item ``idx``."""
        class_indices = None
        if self.visual_data is not None:
            class_indices = self._visual_by_class.get(fault_label)

        if class_indices:
            # Cycle through the class's images so every item gets one.
            row = self.visual_data[class_indices[idx % len(class_indices)]]
            img = row[self._visual_column]
            if not isinstance(img, Image.Image):
                img = Image.fromarray(np.array(img))
        else:
            img = self.visual_gen.generate(fault_name, variation_seed=idx * 7 + fault_label)

        if img.mode != "RGB":
            img = img.convert("RGB")
        return img

    def _to_pixel_values(self, img):
        """Convert ``img`` to the image tensor fed to the model."""
        if self.vit_processor:
            return self.vit_processor(images=img, return_tensors="pt")["pixel_values"].squeeze(0)

        side = self._FALLBACK_IMAGE_SIZE
        pixels = np.asarray(img.resize((side, side)), dtype=np.float32) / 255.0
        # Keep the statistics in float32: plain Python lists would promote
        # the result to float64 and produce a double tensor.
        mean = np.asarray(self._FALLBACK_MEAN, dtype=np.float32)
        std = np.asarray(self._FALLBACK_STD, dtype=np.float32)
        # (H, W, C) -> (C, H, W)
        return torch.tensor((pixels - mean) / std).permute(2, 0, 1)

    def __len__(self):
        """Return the number of indexed audio samples."""
        return len(self.samples)

    def __getitem__(self, idx):
        """Return the ``idx``-th item as a dict of tensors.

        Keys: ``"pixel_values"`` (image tensor), ``"audio_values"`` (AST
        input values or a log-mel spectrogram) and ``"labels"`` (scalar long
        tensor holding the fault label).
        """
        sample = self.samples[idx]
        fault_label = sample["fault_label"]
        fault_name = FAULT_CLASSES[fault_label]

        # Audio
        audio_values = self._load_audio(sample["audio_idx"])

        # Visual
        img = self._load_image(idx, fault_label, fault_name)
        pixel_values = self._to_pixel_values(img)

        return {
            "pixel_values": pixel_values,
            "audio_values": audio_values,
            "labels": torch.tensor(fault_label, dtype=torch.long),
        }