File size: 9,134 Bytes
73d65f1 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 | """
Multimodal PC Fault Detection - Flexible Dataset Module
========================================================
Supports multiple data source modes:
Audio: "esc50" (default) | "custom" (any HF audio dataset)
Visual: "synthetic" (default) | "custom" (any HF image dataset)
See examples at bottom of file or in README.
"""
import os, random
import numpy as np
import torch
from torch.utils.data import Dataset
from PIL import Image, ImageDraw
from typing import Dict, Optional
import torchaudio.transforms as T
from datasets import load_dataset
from config import (FAULT_CLASSES, ESC50_TO_FAULT, VISUAL_SYNTHESIS, DataConfig, ModelConfig)
class SyntheticVisualGenerator:
    """Render deterministic placeholder "screenshots" for each fault class.

    An image is a solid background derived from the class template's
    ``color_dominant`` colour, overlaid with the template's ``text_overlay``
    lines and perturbed by per-pixel integer noise. All randomness comes from
    ``variation_seed``, so equal arguments always yield equal images.
    """

    def __init__(self, image_size=224):
        # Side length (pixels) of the square images produced by generate().
        self.image_size = image_size
        # Per-class rendering templates (config.VISUAL_SYNTHESIS).
        self.templates = VISUAL_SYNTHESIS

    def generate(self, fault_class_name, variation_seed=0):
        """Return an RGB ``PIL.Image`` for ``fault_class_name``.

        Args:
            fault_class_name: key into ``self.templates`` (KeyError if absent).
            variation_seed: seeds both the Python RNG (colour jitter, text
                placement) and the NumPy RNG (pixel noise).
        """
        jitter = random.Random(variation_seed)
        spec = self.templates[fault_class_name]
        base_red, base_green, base_blue = spec["color_dominant"]
        # Shift each channel by up to +/-30, clamped to the valid 0..255 range.
        red, green, blue = [
            max(0, min(255, channel + jitter.randint(-30, 30)))
            for channel in (base_red, base_green, base_blue)
        ]
        canvas = Image.new("RGB", (self.image_size, self.image_size), (red, green, blue))
        pen = ImageDraw.Draw(canvas)
        # White text on dark backgrounds, black text on light ones.
        is_dark = (red + green + blue) / 3 < 128
        ink = (255, 255, 255) if is_dark else (0, 0, 0)
        x_upper = max(10, self.image_size // 4)
        cursor_y = jitter.randint(10, 40)
        for caption in spec["text_overlay"]:
            cursor_x = jitter.randint(5, x_upper)
            pen.text((cursor_x, cursor_y), caption, fill=ink)
            cursor_y += jitter.randint(25, 45)
        # int16 leaves head-room for the +/-15 noise before clipping back to uint8.
        pixels = np.array(canvas).astype(np.int16)
        noise = np.random.RandomState(variation_seed).randint(-15, 16, pixels.shape, dtype=np.int16)
        noisy = np.clip(pixels + noise, 0, 255)
        return Image.fromarray(noisy.astype(np.uint8))
class AudioPreprocessor:
    """Convert raw waveforms into fixed-length log-mel (dB) spectrograms.

    Reads from ``config``: ``sample_rate``, ``audio_duration``, ``n_fft``,
    ``hop_length``, ``n_mels``, ``fmin``, ``fmax`` and, in :meth:`augment`,
    ``freq_mask_max`` / ``time_mask_max``.
    """

    def __init__(self, config):
        self.config = config
        # Every clip is padded or truncated to exactly this many samples.
        self.target_length = int(config.sample_rate * config.audio_duration)
        self.mel = T.MelSpectrogram(sample_rate=config.sample_rate, n_fft=config.n_fft,
                                    hop_length=config.hop_length, n_mels=config.n_mels,
                                    f_min=config.fmin, f_max=config.fmax)
        self.to_db = T.AmplitudeToDB(top_db=80)
        # T.Resample precomputes its filter kernel when constructed, so keep one
        # instance per source rate instead of rebuilding it for every clip.
        self._resamplers = {}

    def _resample(self, waveform, sr):
        """Resample ``waveform`` from ``sr`` Hz to ``config.sample_rate`` with a cached transform."""
        resampler = self._resamplers.get(sr)
        if resampler is None:
            resampler = T.Resample(sr, self.config.sample_rate)
            self._resamplers[sr] = resampler
        return resampler(waveform)

    def process(self, waveform, sr):
        """Return the log-mel spectrogram (dB, top_db=80) of ``waveform``.

        Args:
            waveform: 1-D tensor, or a tensor whose dim 0 is treated as channels
                (averaged to mono when there is more than one).
            sr: sampling rate of ``waveform`` in Hz.

        The signal is resampled to ``config.sample_rate`` if needed and
        zero-padded / truncated to ``target_length`` samples before the mel
        and dB transforms, so for 1-D/2-D input the result is
        ``(1, n_mels, frames)``.
        """
        if waveform.dim() > 1 and waveform.shape[0] > 1:
            waveform = waveform.mean(0, keepdim=True)
        elif waveform.dim() == 1:
            waveform = waveform.unsqueeze(0)
        if sr != self.config.sample_rate:
            waveform = self._resample(waveform, sr)
        if waveform.shape[-1] < self.target_length:
            waveform = torch.nn.functional.pad(waveform, (0, self.target_length - waveform.shape[-1]))
        else:
            waveform = waveform[:, :self.target_length]
        return self.to_db(self.mel(waveform))

    def augment(self, x, training=True):
        """Apply random frequency masking, then time masking, when ``training``.

        Mask widths come from ``config.freq_mask_max`` / ``config.time_mask_max``.
        """
        if not training:
            return x
        # NOTE(review): torchaudio's default mask_value is 0.0, i.e. 0 dB on this
        # unnormalised dB spectrogram rather than "silence" — confirm intended.
        freq_masked = T.FrequencyMasking(self.config.freq_mask_max)(x)
        return T.TimeMasking(self.config.time_mask_max)(freq_masked)
class PCFaultDataset(Dataset):
    """
    Flexible multimodal dataset. Works with any HF audio + image dataset.
    AUDIO: set audio_source="custom" and provide:
    custom_audio_dataset, custom_audio_column, custom_audio_label_column, custom_audio_label_map
    VISUAL: set visual_source="custom" and provide:
    custom_visual_dataset, custom_visual_column, custom_visual_label_column, custom_visual_label_map

    Label maps translate a dataset's original labels (looked up as-is, then as
    ``str(label)``) to fault-class indices ``0 .. len(FAULT_CLASSES) - 1``;
    rows whose label is not in the map are skipped. Each item pairs one audio
    clip with an image of the same fault class: a custom image when one is
    indexed for that class, otherwise a SyntheticVisualGenerator image.
    Items are dicts with ``pixel_values``, ``audio_values`` and ``labels``.

    Raises:
        ValueError: unknown ``audio_source``, missing required custom-audio
            arguments, or a label-map value that is not a valid fault index.
    """

    # ImageNet statistics for the no-processor path. float32 so the normalised
    # tensor stays float32 (subtracting a plain list would promote to float64).
    _IMAGENET_MEAN = np.array([0.485, 0.456, 0.406], dtype=np.float32)
    _IMAGENET_STD = np.array([0.229, 0.224, 0.225], dtype=np.float32)

    def __init__(self, config, model_config, split="train", vit_processor=None, ast_feature_extractor=None, augment=True,
                 audio_source="esc50", custom_audio_dataset=None, custom_audio_config=None, custom_audio_split=None,
                 custom_audio_column="audio", custom_audio_label_column="label", custom_audio_label_map=None,
                 visual_source="synthetic", custom_visual_dataset=None, custom_visual_config=None, custom_visual_split=None,
                 custom_visual_column="image", custom_visual_label_column="label", custom_visual_label_map=None):
        # NOTE(review): model_config is accepted for API compatibility but unused here.
        self.config, self.split = config, split
        # Spectrogram augmentation is only ever applied to the training split.
        self.augment = augment and (split == "train")
        self.vit_processor, self.ast_feature_extractor = vit_processor, ast_feature_extractor
        self.audio_source, self.visual_source = audio_source, visual_source
        self.audio_preprocessor = AudioPreprocessor(config)
        self.visual_gen = SyntheticVisualGenerator(config.image_size)
        self._audio_column, self._visual_column = custom_audio_column, custom_visual_column
        # Resample transforms for the AST path, cached per (source, target) rate.
        self._ast_resamplers = {}

        # Load audio
        if audio_source == "esc50":
            self.audio_data, self.samples = self._load_esc50(split)
        elif audio_source == "custom":
            if not custom_audio_dataset:
                raise ValueError("custom_audio_dataset required")
            if not custom_audio_label_map:
                raise ValueError("custom_audio_label_map required: {original_label: 0-4}")
            self.audio_data = self._load_hf(custom_audio_dataset, custom_audio_config, custom_audio_split or split)
            self.samples = []
            # Column access reads labels without decoding every clip's audio.
            for idx, lbl in enumerate(self.audio_data[custom_audio_label_column]):
                mapped = self._map_label(custom_audio_label_map, lbl)
                if mapped is not None:
                    self.samples.append({"audio_idx": idx, "fault_label": mapped})
        else:
            raise ValueError(f"unknown audio_source {audio_source!r}; expected 'esc50' or 'custom'")

        # Load visual
        self.visual_data = None
        self._visual_by_class = {i: [] for i in range(len(FAULT_CLASSES))}
        if visual_source == "custom" and custom_visual_dataset:
            self.visual_data = self._load_hf(custom_visual_dataset, custom_visual_config, custom_visual_split or split)
            self._index_visuals(custom_visual_label_column, custom_visual_label_map)
            print(f"[Visual] Custom: {custom_visual_dataset}, per-class: { {k:len(v) for k,v in self._visual_by_class.items()} }")
        elif visual_source != "synthetic":
            print(f"[Visual] WARNING: visual_source={visual_source!r} without custom_visual_dataset; using synthetic images")

        self._print_stats()

    @staticmethod
    def _load_hf(path, name, split):
        """Load one split of a Hugging Face dataset, passing ``name`` only when given."""
        kw = {"name": name} if name else {}
        return load_dataset(path, split=split, **kw)

    def _load_esc50(self, split):
        """Load ESC-50 and select clips whose category maps to a fault class.

        The dataset is loaded as its single "train" split; ``config.val_fold``
        is held out for ``split == "val"`` and the other folds form "train".
        Any other ``split`` selects nothing. Returns ``(dataset, samples)``.
        """
        data = load_dataset(self.config.esc50_dataset, split="train")
        val_fold = self.config.val_fold
        samples = []
        # Read only metadata columns; iterating whole rows would decode every clip.
        for idx, (cat, fold) in enumerate(zip(data["category"], data["fold"])):
            if cat not in ESC50_TO_FAULT:
                continue
            if (split == "train" and fold != val_fold) or (split == "val" and fold == val_fold):
                samples.append({"audio_idx": idx, "fault_label": ESC50_TO_FAULT[cat]})
        return data, samples

    def _index_visuals(self, label_column, label_map):
        """Group custom image row indices by mapped fault class into ``self._visual_by_class``."""
        if label_column in self.visual_data.column_names:
            # Column access reads labels without decoding every image.
            labels = self.visual_data[label_column]
        else:
            # Equivalent to a per-row ``.get``: every label is treated as None.
            print(f"[Visual] WARNING: label column {label_column!r} not found; labels treated as missing")
            labels = [None] * len(self.visual_data)
        for idx, lbl in enumerate(labels):
            mapped = self._map_label(label_map, lbl)
            if mapped is not None:
                self._visual_by_class[mapped].append(idx)

    @staticmethod
    def _map_label(label_map, label):
        """Map an original dataset label to a fault index, or None if unmapped.

        The label is looked up as-is first, then as ``str(label)``.
        Raises ValueError when the mapped value is not an int in range.
        """
        if not label_map:
            return None
        mapped = label_map.get(label, label_map.get(str(label)))
        n_classes = len(FAULT_CLASSES)
        if mapped is not None and not (isinstance(mapped, (int, np.integer)) and 0 <= mapped < n_classes):
            raise ValueError(f"label {label!r} maps to {mapped!r}; expected an int in 0..{n_classes - 1}")
        return mapped

    def _print_stats(self):
        """Print the sample total and per-fault-class counts for this split."""
        counts = {}
        for s in self.samples:
            counts[s["fault_label"]] = counts.get(s["fault_label"], 0) + 1
        print(f"[Dataset] {self.split} | audio={self.audio_source} | visual={self.visual_source} | total={len(self.samples)}")
        for label, count in sorted(counts.items()):
            print(f" {FAULT_CLASSES[label]}: {count}")

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        """Return ``{"pixel_values", "audio_values", "labels"}`` for sample ``idx``."""
        s = self.samples[idx]
        fl, fn = s["fault_label"], FAULT_CLASSES[s["fault_label"]]
        audio_values = self._audio_values(s["audio_idx"])
        pv = self._pixel_values(self._image_for(idx, fl, fn))
        return {"pixel_values": pv, "audio_values": audio_values, "labels": torch.tensor(fl, dtype=torch.long)}

    def _audio_values(self, audio_idx):
        """Decode one audio row into model input: AST features if an extractor is set, else a log-mel."""
        row = self.audio_data[audio_idx]
        ad = row["audio"] if self.audio_source == "esc50" else row[self._audio_column]
        if isinstance(ad, dict) and "array" in ad:
            arr, sr = np.array(ad["array"], dtype=np.float32), ad["sampling_rate"]
        else:
            # A bare array carries no rate; assume it is already at config.sample_rate.
            arr, sr = np.array(ad, dtype=np.float32), self.config.sample_rate
        if self.ast_feature_extractor:
            # Feed the extractor at its own expected rate (16 kHz for AST) when it exposes one.
            target_sr = getattr(self.ast_feature_extractor, "sampling_rate", None) or 16000
            if sr != target_sr:
                key = (sr, target_sr)
                if key not in self._ast_resamplers:
                    self._ast_resamplers[key] = T.Resample(sr, target_sr)
                arr = self._ast_resamplers[key](torch.tensor(arr).unsqueeze(0)).squeeze(0).numpy()
            return self.ast_feature_extractor(arr, sampling_rate=target_sr, return_tensors="pt")["input_values"].squeeze(0)
        lm = self.audio_preprocessor.process(torch.tensor(arr, dtype=torch.float32), sr)
        if self.augment:
            lm = self.audio_preprocessor.augment(lm)
        return lm.squeeze(0)

    def _image_for(self, idx, fl, fn):
        """Pick a custom image of class ``fl`` (cycled by ``idx``) or synthesise one; always RGB."""
        candidates = self._visual_by_class.get(fl)
        if self.visual_data is not None and candidates:
            img = self.visual_data[candidates[idx % len(candidates)]][self._visual_column]
            if not isinstance(img, Image.Image):
                img = Image.fromarray(np.array(img))
        else:
            img = self.visual_gen.generate(fn, variation_seed=idx * 7 + fl)
        return img if img.mode == "RGB" else img.convert("RGB")

    def _pixel_values(self, img):
        """Convert an RGB image to a float32 ``(3, H, W)`` tensor via ``vit_processor`` or ImageNet normalisation."""
        if self.vit_processor:
            return self.vit_processor(images=img, return_tensors="pt")["pixel_values"].squeeze(0)
        size = self.config.image_size
        a = np.asarray(img.resize((size, size)), dtype=np.float32) / 255.0
        return torch.from_numpy((a - self._IMAGENET_MEAN) / self._IMAGENET_STD).permute(2, 0, 1)
def multimodal_collate_fn(batch):
    """Collate dataset items into one batch dict.

    ``pixel_values`` and ``labels`` are stacked as-is. ``audio_values`` are
    right-padded with zeros along their last dimension to the longest item
    in the batch, then stacked.
    """
    images = torch.stack([item["pixel_values"] for item in batch])
    targets = torch.stack([item["labels"] for item in batch])
    clips = [item["audio_values"] for item in batch]
    longest = max(clip.shape[-1] for clip in clips)
    padded = []
    for clip in clips:
        shortfall = longest - clip.shape[-1]
        if shortfall > 0:
            clip = torch.nn.functional.pad(clip, (0, shortfall))
        padded.append(clip)
    return {
        "pixel_values": images,
        "labels": targets,
        "audio_values": torch.stack(padded),
    }
|