# multimodal-pc-fault-detector / src/dataset_real.py
# Real dataset loader (ToyADMOS audio + MVTec AD visual)
"""
Real Dataset Loader — replaces dataset.py
Audio: aegean-ai/engine-anomaly-detection-dataset (ToyADMOS/DCASE)
Visual: TheoM55/mvtec_anomaly_detection (MVTec AD)
Usage: change one import in train.py:
from dataset_real import RealPCFaultDataset as PCFaultDataset, multimodal_collate_fn
"""
import re, random, numpy as np, torch
from torch.utils.data import Dataset
from PIL import Image
from typing import Dict
import torchaudio.transforms as T
from datasets import load_dataset, concatenate_datasets
from collections import Counter
from config import FAULT_CLASSES, DataConfig, ModelConfig
class AudioPreprocessor:
    """Turn raw waveforms into log-mel spectrograms, with optional SpecAugment.

    The config object must provide: sample_rate, audio_duration, n_fft,
    hop_length, n_mels, fmin, fmax, time_mask_max, freq_mask_max.
    """

    def __init__(self, c):
        self.config = c
        # Target clip length in samples after padding/truncation.
        self.tl = int(c.sample_rate * c.audio_duration)
        self.mel = T.MelSpectrogram(sample_rate=c.sample_rate, n_fft=c.n_fft,
                                    hop_length=c.hop_length, n_mels=c.n_mels,
                                    f_min=c.fmin, f_max=c.fmax)
        self.db = T.AmplitudeToDB(top_db=80)
        # Cache one Resample module per source rate: constructing T.Resample
        # recomputes its filter kernel, which is wasteful when done on every
        # process() call (i.e. every __getitem__).
        self._resamplers = {}
        # Masking transforms draw a fresh random mask on each forward pass, so
        # building them once here preserves the original per-call randomness.
        self._freq_mask = T.FrequencyMasking(c.freq_mask_max)
        self._time_mask = T.TimeMasking(c.time_mask_max)

    def process(self, w, sr):
        """Return the log-mel spectrogram (1, n_mels, frames) of waveform `w` sampled at `sr`."""
        # Collapse multi-channel audio to mono and ensure a (1, samples) layout.
        if w.dim() > 1 and w.shape[0] > 1:
            w = w.mean(0, keepdim=True)
        elif w.dim() == 1:
            w = w.unsqueeze(0)
        if sr != self.config.sample_rate:
            if sr not in self._resamplers:
                self._resamplers[sr] = T.Resample(sr, self.config.sample_rate)
            w = self._resamplers[sr](w)
        # Zero-pad short clips on the right, then truncate to exactly tl samples.
        if w.shape[-1] != self.tl:
            w = torch.nn.functional.pad(w, (0, max(0, self.tl - w.shape[-1])))[:, :self.tl]
        return self.db(self.mel(w))

    def augment(self, x):
        """Apply frequency masking then time masking (SpecAugment) to spectrogram `x`."""
        return self._time_mask(self._freq_mask(x))
def parse_toycar_label(key):
    """Map a ToyADMOS sample key to one of 5 fault classes via its `_abNN_` / `_normal_` tag.

    Keys without a recognizable tag fall back to class 0 (normal).
    """
    severity = {"normal": 0,
                "ab01": 1, "ab02": 1,
                "ab03": 2, "ab04": 2,
                "ab05": 3, "ab06": 3,
                "ab07": 4, "ab08": 4, "ab09": 4, "ab10": 4}
    match = re.search(r'_(ab\d{2})_|_(normal)_', key)
    if match is None:
        return 0
    tag = match.group(1) if match.group(1) is not None else match.group(2)
    return severity.get(tag, 0)
def build_mvtec_map(names):
    """Collapse MVTec AD defect-type names into 5 coarse fault classes.

    Returns {label_index: class} where 0=good, 1=missing/damaged part,
    2=contamination/deformation, 3=surface damage, 4=everything else.
    """
    groups = (
        ({"good"}, 0),
        ({"missing_cable", "missing_wire", "cut", "cut_lead", "cut_inner_insulation",
          "cut_outer_insulation", "manipulated_front", "poke", "poke_insulation",
          "cable_swap", "combined"}, 1),
        ({"contamination", "liquid", "oil", "color", "metal_contamination",
          "bent", "bent_lead", "glue", "glue_strip"}, 2),
        ({"crack", "scratch", "scratch_head", "scratch_neck", "broken_large",
          "broken_small", "split_teeth", "broken_teeth", "rough", "thread_side",
          "thread_top", "squeeze", "fabric_border", "fabric_interior"}, 3),
    )
    mapping = {}
    for idx, name in enumerate(names):
        # Default bucket for anything that matches no known group.
        mapping[idx] = 4
        for members, cls in groups:
            if name in members:
                mapping[idx] = cls
                break
    return mapping
class RealPCFaultDataset(Dataset):
    """Cross-paired audio+visual fault dataset over 5 shared fault classes.

    Audio comes from the ToyADMOS/DCASE engine-anomaly dataset and visuals
    from MVTec AD; both are relabelled onto the same 5 classes, then paired
    class-by-class so each item yields one audio tensor, one image tensor,
    and a single fault label.
    """

    def __init__(self, config, model_config, split="train", vit_processor=None, ast_feature_extractor=None, augment=True, val_ratio=0.2, seed=42):
        # NOTE(review): model_config is accepted but never read in this class —
        # presumably kept for signature compatibility with the dataset it replaces.
        # Spectrogram augmentation is enabled only on the training split.
        self.config, self.split, self.augment = config, split, augment and (split=="train")
        # Optional HF processors; when absent, __getitem__ falls back to manual preprocessing.
        self.vit_processor, self.ast_feature_extractor = vit_processor, ast_feature_extractor
        self.audio_preprocessor = AudioPreprocessor(config)
        print(f"\n[Real] Loading audio...")
        # Both upstream splits are merged; the real train/val split happens below on the pairs.
        an = load_dataset("aegean-ai/engine-anomaly-detection-dataset", split="train")
        aa = load_dataset("aegean-ai/engine-anomaly-detection-dataset", split="test")
        # Fault classes are parsed out of the sample keys (e.g. "..._ab03_...").
        # Assumes a "__key__" column exists on this dataset — TODO confirm.
        self.audio_labels = [parse_toycar_label(k) for k in an["__key__"]] + [parse_toycar_label(k) for k in aa["__key__"]]
        self.audio_data = concatenate_datasets([an, aa])
        print(f" Audio: {len(self.audio_data)}, dist: {Counter(self.audio_labels)}")
        print(f"[Real] Loading visual...")
        vt = load_dataset("TheoM55/mvtec_anomaly_detection", split="train")
        ve = load_dataset("TheoM55/mvtec_anomaly_detection", split="test")
        # Collapse MVTec's fine-grained defect names onto the 5 coarse classes.
        mm = build_mvtec_map(ve.features["label"].names)
        # Every train-split image is assigned class 0 (MVTec train sets are defect-free).
        self.visual_labels = [0]*len(vt) + [mm.get(l,4) for l in ve["label"]]
        self.visual_data = concatenate_datasets([vt, ve])
        print(f" Visual: {len(self.visual_data)}, dist: {Counter(self.visual_labels)}")
        # Bucket sample indices by class for each modality.
        abc = {i:[] for i in range(5)}
        for i,l in enumerate(self.audio_labels): abc[l].append(i)
        vbc = {i:[] for i in range(5)}
        for i,l in enumerate(self.visual_labels): vbc[l].append(i)
        rng, pairs = random.Random(seed), []
        for c in range(5):
            a,v = abc[c], vbc[c]
            if not a or not v: print(f" WARN: class {c} missing data"); continue
            # Pair up to the larger bucket's size, cycling the smaller bucket's indices.
            for i in range(max(len(a),len(v))): pairs.append({"audio_idx":a[i%len(a)],"visual_idx":v[i%len(v)],"fault_label":c})
        # Seeded shuffle so the train/val split is reproducible across runs.
        rng.shuffle(pairs)
        si = int(len(pairs)*(1-val_ratio))
        self.samples = pairs[:si] if split=="train" else pairs[si:]
        lc = Counter(s["fault_label"] for s in self.samples)
        print(f"\n[Real] {split}: {len(self.samples)} pairs")
        for l in range(5): print(f" {FAULT_CLASSES[l]}: {lc.get(l,0)}")

    def __len__(self): return len(self.samples)

    def __getitem__(self, idx):
        """Return {"pixel_values", "audio_values", "labels"} for pair `idx`."""
        s = self.samples[idx]; fl = s["fault_label"]
        # Audio lives in the "wav" column; HF audio features typically decode to
        # {"array": ..., "sampling_rate": ...} — fall back to a raw array otherwise.
        ad = self.audio_data[s["audio_idx"]]["wav"]
        if isinstance(ad,dict) and "array" in ad: arr,sr = np.array(ad["array"],dtype=np.float32), ad["sampling_rate"]
        else: arr,sr = np.array(ad,dtype=np.float32), self.config.sample_rate
        if self.ast_feature_extractor:
            # AST feature extractors are fed 16 kHz audio here; resample first if needed.
            if sr!=16000: arr = T.Resample(sr,16000)(torch.tensor(arr).unsqueeze(0)).squeeze(0).numpy()
            av = self.ast_feature_extractor(arr, sampling_rate=16000, return_tensors="pt")["input_values"].squeeze(0)
        else:
            # Manual path: log-mel spectrogram, plus masking augmentation when training.
            lm = self.audio_preprocessor.process(torch.tensor(arr,dtype=torch.float32), sr)
            if self.augment: lm = self.audio_preprocessor.augment(lm)
            av = lm.squeeze(0)
        img = self.visual_data[s["visual_idx"]]["image"]
        if not isinstance(img, Image.Image): img = Image.fromarray(np.array(img))
        if img.mode!="RGB": img = img.convert("RGB")
        if self.vit_processor: pv = self.vit_processor(images=img, return_tensors="pt")["pixel_values"].squeeze(0)
        else:
            # Manual path: resize to 224x224 and normalize with ImageNet mean/std.
            a = np.array(img.resize((224,224))).astype(np.float32)/255.0
            pv = torch.tensor((a-[.485,.456,.406])/[.229,.224,.225]).permute(2,0,1)
        return {"pixel_values": pv, "audio_values": av, "labels": torch.tensor(fl, dtype=torch.long)}
def multimodal_collate_fn(batch):
    """Collate sample dicts into batched tensors, right-padding audio to the longest clip.

    Expects each item to carry "pixel_values", "audio_values" and "labels";
    images and labels are stacked directly, audio is zero-padded on its last
    dimension to a common length first.
    """
    images = torch.stack([item["pixel_values"] for item in batch])
    targets = torch.stack([item["labels"] for item in batch])
    clips = [item["audio_values"] for item in batch]
    longest = max(clip.shape[-1] for clip in clips)
    padded = []
    for clip in clips:
        deficit = longest - clip.shape[-1]
        if deficit > 0:
            clip = torch.nn.functional.pad(clip, (0, deficit))
        padded.append(clip)
    return {"pixel_values": images, "labels": targets, "audio_values": torch.stack(padded)}