""" Real Dataset Loader — replaces dataset.py Audio: aegean-ai/engine-anomaly-detection-dataset (ToyADMOS/DCASE) Visual: TheoM55/mvtec_anomaly_detection (MVTec AD) Usage: change one import in train.py: from dataset_real import RealPCFaultDataset as PCFaultDataset, multimodal_collate_fn """ import re, random, numpy as np, torch from torch.utils.data import Dataset from PIL import Image from typing import Dict import torchaudio.transforms as T from datasets import load_dataset, concatenate_datasets from collections import Counter from config import FAULT_CLASSES, DataConfig, ModelConfig class AudioPreprocessor: def __init__(s, c): s.config, s.tl = c, int(c.sample_rate * c.audio_duration) s.mel = T.MelSpectrogram(sample_rate=c.sample_rate, n_fft=c.n_fft, hop_length=c.hop_length, n_mels=c.n_mels, f_min=c.fmin, f_max=c.fmax) s.db = T.AmplitudeToDB(top_db=80) def process(s, w, sr): if w.dim()>1 and w.shape[0]>1: w=w.mean(0,keepdim=True) elif w.dim()==1: w=w.unsqueeze(0) if sr!=s.config.sample_rate: w=T.Resample(sr,s.config.sample_rate)(w) w = torch.nn.functional.pad(w,(0,max(0,s.tl-w.shape[-1])))[:,:s.tl] if w.shape[-1]!=s.tl else w return s.db(s.mel(w)) def augment(s, x): return T.TimeMasking(s.config.time_mask_max)(T.FrequencyMasking(s.config.freq_mask_max)(x)) def parse_toycar_label(key): m = re.search(r'_(ab\d{2})_|_(normal)_', key) c = (m.group(1) or m.group(2)) if m else "normal" return {"normal":0,"ab01":1,"ab02":1,"ab03":2,"ab04":2,"ab05":3,"ab06":3,"ab07":4,"ab08":4,"ab09":4,"ab10":4}.get(c,0) def build_mvtec_map(names): m = {} for i, n in enumerate(names): if n=="good": m[i]=0 elif n in ["missing_cable","missing_wire","cut","cut_lead","cut_inner_insulation","cut_outer_insulation","manipulated_front","poke","poke_insulation","cable_swap","combined"]: m[i]=1 elif n in ["contamination","liquid","oil","color","metal_contamination","bent","bent_lead","glue","glue_strip"]: m[i]=2 elif n in ["crack","scratch","scratch_head","scratch_neck","broken_large","broken_small","split_teeth","broken_teeth","rough","thread_side","thread_top","squeeze","fabric_border","fabric_interior"]: m[i]=3 else: m[i]=4 return m class RealPCFaultDataset(Dataset): def __init__(self, config, model_config, split="train", vit_processor=None, ast_feature_extractor=None, augment=True, val_ratio=0.2, seed=42): self.config, self.split, self.augment = config, split, augment and (split=="train") self.vit_processor, self.ast_feature_extractor = vit_processor, ast_feature_extractor self.audio_preprocessor = AudioPreprocessor(config) print(f"\n[Real] Loading audio...") an = load_dataset("aegean-ai/engine-anomaly-detection-dataset", split="train") aa = load_dataset("aegean-ai/engine-anomaly-detection-dataset", split="test") self.audio_labels = [parse_toycar_label(k) for k in an["__key__"]] + [parse_toycar_label(k) for k in aa["__key__"]] self.audio_data = concatenate_datasets([an, aa]) print(f" Audio: {len(self.audio_data)}, dist: {Counter(self.audio_labels)}") print(f"[Real] Loading visual...") vt = load_dataset("TheoM55/mvtec_anomaly_detection", split="train") ve = load_dataset("TheoM55/mvtec_anomaly_detection", split="test") mm = build_mvtec_map(ve.features["label"].names) self.visual_labels = [0]*len(vt) + [mm.get(l,4) for l in ve["label"]] self.visual_data = concatenate_datasets([vt, ve]) print(f" Visual: {len(self.visual_data)}, dist: {Counter(self.visual_labels)}") abc = {i:[] for i in range(5)} for i,l in enumerate(self.audio_labels): abc[l].append(i) vbc = {i:[] for i in range(5)} for i,l in enumerate(self.visual_labels): vbc[l].append(i) rng, pairs = random.Random(seed), [] for c in range(5): a,v = abc[c], vbc[c] if not a or not v: print(f" WARN: class {c} missing data"); continue for i in range(max(len(a),len(v))): pairs.append({"audio_idx":a[i%len(a)],"visual_idx":v[i%len(v)],"fault_label":c}) rng.shuffle(pairs) si = int(len(pairs)*(1-val_ratio)) self.samples = pairs[:si] if split=="train" else pairs[si:] lc = Counter(s["fault_label"] for s in self.samples) print(f"\n[Real] {split}: {len(self.samples)} pairs") for l in range(5): print(f" {FAULT_CLASSES[l]}: {lc.get(l,0)}") def __len__(self): return len(self.samples) def __getitem__(self, idx): s = self.samples[idx]; fl = s["fault_label"] ad = self.audio_data[s["audio_idx"]]["wav"] if isinstance(ad,dict) and "array" in ad: arr,sr = np.array(ad["array"],dtype=np.float32), ad["sampling_rate"] else: arr,sr = np.array(ad,dtype=np.float32), self.config.sample_rate if self.ast_feature_extractor: if sr!=16000: arr = T.Resample(sr,16000)(torch.tensor(arr).unsqueeze(0)).squeeze(0).numpy() av = self.ast_feature_extractor(arr, sampling_rate=16000, return_tensors="pt")["input_values"].squeeze(0) else: lm = self.audio_preprocessor.process(torch.tensor(arr,dtype=torch.float32), sr) if self.augment: lm = self.audio_preprocessor.augment(lm) av = lm.squeeze(0) img = self.visual_data[s["visual_idx"]]["image"] if not isinstance(img, Image.Image): img = Image.fromarray(np.array(img)) if img.mode!="RGB": img = img.convert("RGB") if self.vit_processor: pv = self.vit_processor(images=img, return_tensors="pt")["pixel_values"].squeeze(0) else: a = np.array(img.resize((224,224))).astype(np.float32)/255.0 pv = torch.tensor((a-[.485,.456,.406])/[.229,.224,.225]).permute(2,0,1) return {"pixel_values": pv, "audio_values": av, "labels": torch.tensor(fl, dtype=torch.long)} def multimodal_collate_fn(batch): pv = torch.stack([b["pixel_values"] for b in batch]) labels = torch.stack([b["labels"] for b in batch]) al = [b["audio_values"] for b in batch]; ml = max(a.shape[-1] for a in al) return {"pixel_values":pv,"labels":labels,"audio_values":torch.stack([torch.nn.functional.pad(a,(0,ml-a.shape[-1])) if a.shape[-1]