# multimodal-pc-fault-detector / src/dataset_flexible.py
# Author: Ellaft
# Commit 73d65f1 (verified): Add flexible dataset module supporting custom HF audio and image datasets
"""
Multimodal PC Fault Detection - Flexible Dataset Module
========================================================
Supports multiple data source modes:
Audio: "esc50" (default) | "custom" (any HF audio dataset)
Visual: "synthetic" (default) | "custom" (any HF image dataset)
See examples at bottom of file or in README.
"""
import os, random
import numpy as np
import torch
from torch.utils.data import Dataset
from PIL import Image, ImageDraw
from typing import Dict, Optional
import torchaudio.transforms as T
from datasets import load_dataset
from config import (FAULT_CLASSES, ESC50_TO_FAULT, VISUAL_SYNTHESIS, DataConfig, ModelConfig)
class SyntheticVisualGenerator:
    """Render deterministic placeholder images for each fault class.

    Templates come from ``VISUAL_SYNTHESIS``; each one is read for
    ``color_dominant`` (three channel values) and ``text_overlay`` (lines of
    text to draw onto the image).
    """

    def __init__(self, image_size=224):
        # Side length, in pixels, of the square images produced by generate().
        self.image_size = image_size
        self.templates = VISUAL_SYNTHESIS

    def generate(self, fault_class_name, variation_seed=0):
        """Return a noisy RGB image for *fault_class_name*.

        The background is the template colour with every channel jittered by
        up to +/-30 and clamped to 0..255. The template's text lines are drawn
        top-down at random horizontal offsets, in white on dark backgrounds and
        black on light ones. Finally, integer noise in [-15, 15] is added to
        every pixel. Both random sources are seeded from *variation_seed*, so
        the same ``(fault_class_name, variation_seed)`` pair reproduces the
        same image.
        """
        rng = random.Random(variation_seed)
        template = self.templates[fault_class_name]
        side = self.image_size

        base_r, base_g, base_b = template["color_dominant"]
        background = tuple(
            min(255, max(0, base + rng.randint(-30, 30)))
            for base in (base_r, base_g, base_b)
        )
        image = Image.new("RGB", (side, side), background)

        # Pick a text colour that contrasts with the background (mean < 128 -> white).
        ink = (255, 255, 255) if sum(background) < 3 * 128 else (0, 0, 0)
        canvas = ImageDraw.Draw(image)
        x_high = max(10, side // 4)
        y = rng.randint(10, 40)
        for line in template["text_overlay"]:
            canvas.text((rng.randint(5, x_high), y), line, fill=ink)
            y += rng.randint(25, 45)

        # Widen to int16 so adding negative noise cannot wrap around uint8.
        pixels = np.array(image).astype(np.int16)
        noise = np.random.RandomState(variation_seed).randint(-15, 16, size=pixels.shape, dtype=np.int16)
        return Image.fromarray(np.clip(pixels + noise, 0, 255).astype(np.uint8))
class AudioPreprocessor:
    """Convert raw waveforms into fixed-length log-mel spectrograms.

    ``config`` is read for ``sample_rate``, ``audio_duration``, ``n_fft``,
    ``hop_length``, ``n_mels``, ``fmin`` and ``fmax`` (construction and
    ``process``), and for ``freq_mask_max`` / ``time_mask_max`` (``augment``).
    """

    def __init__(self, config):
        self.config = config
        # Every clip is padded or truncated to exactly this many samples.
        self.target_length = int(config.sample_rate * config.audio_duration)
        self.mel = T.MelSpectrogram(
            sample_rate=config.sample_rate,
            n_fft=config.n_fft,
            hop_length=config.hop_length,
            n_mels=config.n_mels,
            f_min=config.fmin,
            f_max=config.fmax,
        )
        self.to_db = T.AmplitudeToDB(top_db=80)
        # T.Resample precomputes its filter kernel in its constructor. Keep one
        # instance per (source rate, target rate) pair instead of rebuilding the
        # kernel for every clip.
        self._resamplers = {}

    def _resample(self, waveform, orig_sr, new_sr):
        """Resample *waveform* from *orig_sr* to *new_sr* Hz with a cached transform."""
        key = (orig_sr, new_sr)
        resampler = self._resamplers.get(key)
        if resampler is None:
            resampler = self._resamplers[key] = T.Resample(orig_sr, new_sr)
        return resampler(waveform)

    def process(self, waveform, sr):
        """Return the dB-scaled log-mel spectrogram of *waveform* sampled at *sr* Hz.

        A multi-channel 2-D input (channels first) is averaged to mono, and a
        1-D input gains a leading channel axis. The signal is then resampled to
        ``config.sample_rate`` if needed, and right-padded with zeros or
        truncated to ``target_length`` samples. For 1-D/2-D input the result
        therefore has shape ``(1, n_mels, frames)``.
        """
        if waveform.dim() > 1 and waveform.shape[0] > 1:
            waveform = waveform.mean(0, keepdim=True)
        elif waveform.dim() == 1:
            waveform = waveform.unsqueeze(0)
        if sr != self.config.sample_rate:
            waveform = self._resample(waveform, sr, self.config.sample_rate)
        missing = self.target_length - waveform.shape[-1]
        if missing > 0:
            waveform = torch.nn.functional.pad(waveform, (0, missing))
        else:
            waveform = waveform[:, :self.target_length]
        return self.to_db(self.mel(waveform))

    def augment(self, x, training=True):
        """Apply random frequency masking then time masking to *x* when *training* is true."""
        if not training:
            return x
        x = T.FrequencyMasking(self.config.freq_mask_max)(x)
        return T.TimeMasking(self.config.time_mask_max)(x)
class PCFaultDataset(Dataset):
    """
    Flexible multimodal dataset. Works with any HF audio + image dataset.
    AUDIO: set audio_source="custom" and provide:
    custom_audio_dataset, custom_audio_column, custom_audio_label_column, custom_audio_label_map
    VISUAL: set visual_source="custom" and provide:
    custom_visual_dataset, custom_visual_column, custom_visual_label_column, custom_visual_label_map

    One sample is created per indexed audio clip, and that clip's fault label
    selects the paired image.

    Audio sources:
        "esc50"  -- ``config.esc50_dataset``. Keeps clips whose ``category`` is in
                    ``ESC50_TO_FAULT``; fold ``config.val_fold`` is the "val"
                    split and every other fold is "train".
        "custom" -- any HF dataset. Keeps rows whose label is found in
                    ``custom_audio_label_map``.
    Visual sources:
        "synthetic" -- images from ``SyntheticVisualGenerator``.
        "custom"    -- images from ``custom_visual_dataset``, grouped by mapped
                       label. Classes with no mapped image fall back to
                       synthetic images.

    Label maps send an original label (looked up as-is, then as ``str(label)``)
    to an index into ``FAULT_CLASSES``; unmapped labels are skipped.

    Raises:
        ValueError: unknown ``audio_source``, missing custom audio arguments,
            or a label map value that is not a valid ``FAULT_CLASSES`` index.
    """

    # ImageNet channel statistics used when no ViT processor is supplied. They
    # are float32 so normalised pixels stay float32; plain Python float lists
    # would promote the whole image to float64.
    _IMAGENET_MEAN = np.array([0.485, 0.456, 0.406], dtype=np.float32)
    _IMAGENET_STD = np.array([0.229, 0.224, 0.225], dtype=np.float32)

    def __init__(self, config, model_config, split="train", vit_processor=None, ast_feature_extractor=None,
                 augment=True, audio_source="esc50", custom_audio_dataset=None, custom_audio_config=None,
                 custom_audio_split=None, custom_audio_column="audio", custom_audio_label_column="label",
                 custom_audio_label_map=None, visual_source="synthetic", custom_visual_dataset=None,
                 custom_visual_config=None, custom_visual_split=None, custom_visual_column="image",
                 custom_visual_label_column="label", custom_visual_label_map=None):
        # model_config is not read by this class; it stays in the signature for callers.
        self.config, self.split = config, split
        # Augmentation is only ever applied to the training split.
        self.augment = augment and (split == "train")
        self.vit_processor, self.ast_feature_extractor = vit_processor, ast_feature_extractor
        self.audio_source, self.visual_source = audio_source, visual_source
        self.audio_preprocessor = AudioPreprocessor(config)
        self.visual_gen = SyntheticVisualGenerator(config.image_size)
        self._audio_column = custom_audio_column
        # Cached T.Resample transforms keyed by (source rate, target rate).
        self._resamplers = {}

        if audio_source == "esc50":
            self._index_esc50(split)
        elif audio_source == "custom":
            self._index_custom_audio(split, custom_audio_dataset, custom_audio_config, custom_audio_split,
                                     custom_audio_label_column, custom_audio_label_map)
        else:
            raise ValueError(f"audio_source must be 'esc50' or 'custom', got {audio_source!r}")

        self.visual_data = None
        self._visual_by_class = {i: [] for i in range(len(FAULT_CLASSES))}
        if visual_source == "custom" and custom_visual_dataset:
            self._index_custom_visual(split, custom_visual_dataset, custom_visual_config, custom_visual_split,
                                      custom_visual_column, custom_visual_label_column, custom_visual_label_map)
        elif visual_source != "synthetic":
            print(f"[Visual] WARNING: visual_source={visual_source!r} without custom_visual_dataset; "
                  f"using synthetic images")

        self._print_stats()

    # ------------------------------------------------------------------ indexing

    def _index_esc50(self, split):
        """Load ESC-50 and index the clips that map to a fault class for *split*."""
        self.audio_data = load_dataset(self.config.esc50_dataset, split="train")
        if split not in ("train", "val"):
            print(f"[Dataset] WARNING: ESC-50 is split by fold into 'train'/'val' only; "
                  f"split={split!r} selects no clips")
        val_fold = self.config.val_fold
        # Read only the metadata columns; iterating whole rows would decode
        # every clip's audio just to look at its labels.
        metadata = zip(self.audio_data["category"], self.audio_data["fold"])
        self.samples = []
        for idx, (category, fold) in enumerate(metadata):
            if category not in ESC50_TO_FAULT:
                continue
            is_val = fold == val_fold
            if (split == "train" and not is_val) or (split == "val" and is_val):
                self.samples.append({"audio_idx": idx, "fault_label": ESC50_TO_FAULT[category]})

    def _index_custom_audio(self, split, dataset_id, config_name, split_override, label_column, label_map):
        """Load a custom HF audio dataset and index rows whose label is in *label_map*."""
        if not dataset_id:
            raise ValueError("custom_audio_dataset required")
        if not label_map:
            raise ValueError(
                f"custom_audio_label_map required: {{original_label: 0-{len(FAULT_CLASSES) - 1}}}")
        kw = {"name": config_name} if config_name else {}
        self.audio_data = load_dataset(dataset_id, split=split_override or split, **kw)
        self.samples = []
        # Column access avoids decoding the audio of every row during indexing.
        for idx, label in enumerate(self.audio_data[label_column]):
            mapped = self._map_label(label_map, label)
            if mapped is not None:
                self.samples.append({"audio_idx": idx, "fault_label": mapped})

    def _index_custom_visual(self, split, dataset_id, config_name, split_override, image_column, label_column,
                             label_map):
        """Load a custom HF image dataset and group its row indices by mapped fault class."""
        kw = {"name": config_name} if config_name else {}
        self.visual_data = load_dataset(dataset_id, split=split_override or split, **kw)
        self._visual_column = image_column
        # Read only the label column so images are not decoded during indexing.
        # A missing label column is treated as all-None labels.
        if label_column in self.visual_data.column_names:
            labels = self.visual_data[label_column]
        else:
            labels = [None] * len(self.visual_data)
        for idx, label in enumerate(labels):
            mapped = self._map_label(label_map, label)
            if mapped is not None:
                self._visual_by_class[mapped].append(idx)
        print(f"[Visual] Custom: {dataset_id}, per-class: { {k:len(v) for k,v in self._visual_by_class.items()} }")

    @staticmethod
    def _map_label(label_map, label):
        """Map an original dataset label to a fault-class index, or return None if unmapped.

        The label is looked up as-is first, then as ``str(label)``.

        Raises:
            ValueError: if the mapped value is not a valid index into ``FAULT_CLASSES``.
        """
        if not label_map:
            return None
        mapped = label_map.get(label, label_map.get(str(label)))
        if mapped is not None and mapped not in range(len(FAULT_CLASSES)):
            raise ValueError(
                f"label {label!r} maps to {mapped!r}; expected an index in 0..{len(FAULT_CLASSES) - 1}")
        return mapped

    def _print_stats(self):
        """Print the split summary and per-class sample counts."""
        counts = {}
        for sample in self.samples:
            counts[sample["fault_label"]] = counts.get(sample["fault_label"], 0) + 1
        print(f"[Dataset] {self.split} | audio={self.audio_source} | visual={self.visual_source} "
              f"| total={len(self.samples)}")
        for label, count in sorted(counts.items()):
            print(f" {FAULT_CLASSES[label]}: {count}")

    # ------------------------------------------------------------------ access

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        """Return ``{"pixel_values", "audio_values", "labels"}`` for sample *idx*.

        ``audio_values`` is one of two things. If an AST feature extractor was
        given, it is that extractor's ``input_values`` (audio first resampled to
        16 kHz). Otherwise it is the ``AudioPreprocessor`` log-mel spectrogram,
        frequency/time-masked on the training split.

        ``pixel_values`` comes from ``vit_processor`` when one was given.
        Otherwise it is a 224x224 ImageNet-normalised float32 CHW tensor.

        ``labels`` is an int64 scalar tensor.
        """
        sample = self.samples[idx]
        label = sample["fault_label"]
        audio_values = self._load_audio(sample["audio_idx"])
        pixel_values = self._load_pixels(idx, label)
        return {"pixel_values": pixel_values, "audio_values": audio_values,
                "labels": torch.tensor(label, dtype=torch.long)}

    def _resample(self, waveform, orig_sr, new_sr):
        """Resample *waveform* with a cached ``T.Resample`` (its kernel is built once per rate pair)."""
        key = (orig_sr, new_sr)
        resampler = self._resamplers.get(key)
        if resampler is None:
            resampler = self._resamplers[key] = T.Resample(orig_sr, new_sr)
        return resampler(waveform)

    def _load_audio(self, audio_idx):
        """Decode one audio row and turn it into model input features."""
        row = self.audio_data[audio_idx]
        audio = row["audio"] if self.audio_source == "esc50" else row[self._audio_column]
        if isinstance(audio, dict) and "array" in audio:
            arr, sr = np.array(audio["array"], dtype=np.float32), audio["sampling_rate"]
        else:
            # NOTE(review): a bare array is assumed to be at config.sample_rate already.
            arr, sr = np.array(audio, dtype=np.float32), self.config.sample_rate

        if self.ast_feature_extractor is not None:
            if sr != 16000:
                arr = self._resample(torch.from_numpy(arr).unsqueeze(0), sr, 16000).squeeze(0).numpy()
            features = self.ast_feature_extractor(arr, sampling_rate=16000, return_tensors="pt")
            return features["input_values"].squeeze(0)

        log_mel = self.audio_preprocessor.process(torch.tensor(arr, dtype=torch.float32), sr)
        if self.augment:
            log_mel = self.audio_preprocessor.augment(log_mel)
        return log_mel.squeeze(0)

    def _load_pixels(self, idx, label):
        """Pick (custom) or synthesise an image for *label* and convert it to model input."""
        candidates = self._visual_by_class.get(label) if self.visual_data is not None else None
        if candidates:
            # Cycle deterministically through the images available for this class.
            img = self.visual_data[candidates[idx % len(candidates)]][self._visual_column]
            if not isinstance(img, Image.Image):
                img = Image.fromarray(np.array(img))
        else:
            img = self.visual_gen.generate(FAULT_CLASSES[label], variation_seed=idx * 7 + label)
        if img.mode != "RGB":
            img = img.convert("RGB")

        if self.vit_processor is not None:
            return self.vit_processor(images=img, return_tensors="pt")["pixel_values"].squeeze(0)

        # Fallback path: fixed 224x224 resize plus ImageNet normalisation, kept in float32.
        scaled = np.asarray(img.resize((224, 224)), dtype=np.float32) / 255.0
        normalised = (scaled - self._IMAGENET_MEAN) / self._IMAGENET_STD
        return torch.from_numpy(normalised).permute(2, 0, 1)
def multimodal_collate_fn(batch):
    """Collate dataset items into one batch dict.

    ``pixel_values`` and ``labels`` are stacked unchanged. Each ``audio_values``
    tensor is right-padded with zeros along its last axis to the longest item
    in the batch, and then the padded tensors are stacked.
    """
    pixel_values = torch.stack([item["pixel_values"] for item in batch])
    labels = torch.stack([item["labels"] for item in batch])
    clips = [item["audio_values"] for item in batch]
    longest = max(clip.shape[-1] for clip in clips)

    padded = []
    for clip in clips:
        deficit = longest - clip.shape[-1]
        padded.append(torch.nn.functional.pad(clip, (0, deficit)) if deficit > 0 else clip)

    return {"pixel_values": pixel_values, "labels": labels, "audio_values": torch.stack(padded)}