| """ |
| Real Dataset Loader — replaces dataset.py |
| Audio: aegean-ai/engine-anomaly-detection-dataset (ToyADMOS/DCASE) |
| Visual: TheoM55/mvtec_anomaly_detection (MVTec AD) |
| |
| Usage: change one import in train.py: |
| from dataset_real import RealPCFaultDataset as PCFaultDataset, multimodal_collate_fn |
| """ |
| import re, random, numpy as np, torch |
| from torch.utils.data import Dataset |
| from PIL import Image |
| from typing import Dict |
| import torchaudio.transforms as T |
| from datasets import load_dataset, concatenate_datasets |
| from collections import Counter |
| from config import FAULT_CLASSES, DataConfig, ModelConfig |
|
|
class AudioPreprocessor:
    """Convert raw waveforms into fixed-length log-mel spectrograms.

    Handles mono-mixdown, resampling to the configured rate, padding/cropping
    to a fixed sample count, and (via ``augment``) SpecAugment-style masking.
    """

    def __init__(s, c):
        # c: data config providing sample_rate, audio_duration, n_fft,
        # hop_length, n_mels, fmin, fmax, freq_mask_max, time_mask_max.
        # s.tl is the target clip length in samples.
        s.config, s.tl = c, int(c.sample_rate * c.audio_duration)
        s.mel = T.MelSpectrogram(sample_rate=c.sample_rate, n_fft=c.n_fft,
                                 hop_length=c.hop_length, n_mels=c.n_mels,
                                 f_min=c.fmin, f_max=c.fmax)
        s.db = T.AmplitudeToDB(top_db=80)
        # Hoisted out of augment(): the original rebuilt both transforms on
        # every call. They are reusable modules; masking stays random per call.
        s.freq_mask = T.FrequencyMasking(c.freq_mask_max)
        s.time_mask = T.TimeMasking(c.time_mask_max)

    def process(s, w, sr):
        """Return a (1, n_mels, frames) log-mel spectrogram for waveform ``w``."""
        # Collapse multi-channel audio to mono; ensure a channel dimension.
        if w.dim() > 1 and w.shape[0] > 1:
            w = w.mean(0, keepdim=True)
        elif w.dim() == 1:
            w = w.unsqueeze(0)
        if sr != s.config.sample_rate:
            w = T.Resample(sr, s.config.sample_rate)(w)
        # Zero-pad short clips on the right, then crop to exactly s.tl samples.
        if w.shape[-1] != s.tl:
            w = torch.nn.functional.pad(w, (0, max(0, s.tl - w.shape[-1])))[:, :s.tl]
        return s.db(s.mel(w))

    def augment(s, x):
        """Apply frequency masking then time masking (training-time only)."""
        return s.time_mask(s.freq_mask(x))
|
|
def parse_toycar_label(key):
    """Map a ToyADMOS/ToyCar sample key to one of five fault-class ids.

    Keys embed either an ``_abNN_`` anomaly tag or ``_normal_``; keys that
    match neither pattern default to normal (class 0).
    """
    match = re.search(r'_(ab\d{2})_|_(normal)_', key)
    if match:
        tag = match.group(1) or match.group(2)
    else:
        tag = "normal"
    label_by_tag = {
        "normal": 0,
        "ab01": 1, "ab02": 1,
        "ab03": 2, "ab04": 2,
        "ab05": 3, "ab06": 3,
        "ab07": 4, "ab08": 4, "ab09": 4, "ab10": 4,
    }
    return label_by_tag.get(tag, 0)
|
|
def build_mvtec_map(names):
    """Map MVTec AD label indices onto the project's five fault classes.

    Args:
        names: list of MVTec defect-category names, indexed by label id
            (as exposed by ``dataset.features["label"].names``).

    Returns:
        dict mapping label index -> fault class id:
        0 = good, 1 = missing/removed/cut part, 2 = contamination or
        deformation, 3 = cracks/scratches/breakage, 4 = any other defect.
    """
    # Membership sets are built once here; the original re-created the list
    # literals on every loop iteration and membership-tested them in O(n).
    missing = frozenset({
        "missing_cable", "missing_wire", "cut", "cut_lead",
        "cut_inner_insulation", "cut_outer_insulation", "manipulated_front",
        "poke", "poke_insulation", "cable_swap", "combined",
    })
    contamination = frozenset({
        "contamination", "liquid", "oil", "color", "metal_contamination",
        "bent", "bent_lead", "glue", "glue_strip",
    })
    damage = frozenset({
        "crack", "scratch", "scratch_head", "scratch_neck", "broken_large",
        "broken_small", "split_teeth", "broken_teeth", "rough", "thread_side",
        "thread_top", "squeeze", "fabric_border", "fabric_interior",
    })

    def classify(name):
        # Bucket a single MVTec defect name into a fault-class id.
        if name == "good":
            return 0
        if name in missing:
            return 1
        if name in contamination:
            return 2
        if name in damage:
            return 3
        return 4

    return {i: classify(n) for i, n in enumerate(names)}
|
|
class RealPCFaultDataset(Dataset):
    """Paired audio+visual fault dataset built from two public HF datasets.

    ToyADMOS engine recordings and MVTec AD images are each mapped onto the
    same five fault classes, then paired class-by-class so every sample
    yields one (audio, image) pair sharing a single fault label. The
    train/val split is a seeded shuffle of the constructed pairs.
    """

    def __init__(self, config, model_config, split="train", vit_processor=None, ast_feature_extractor=None, augment=True, val_ratio=0.2, seed=42):
        # Augmentation is only ever active for the training split.
        self.config, self.split, self.augment = config, split, augment and (split=="train")
        self.vit_processor, self.ast_feature_extractor = vit_processor, ast_feature_extractor
        self.audio_preprocessor = AudioPreprocessor(config)
        print(f"\n[Real] Loading audio...")
        # Both provider splits are merged; our own train/val split happens below.
        an = load_dataset("aegean-ai/engine-anomaly-detection-dataset", split="train")
        aa = load_dataset("aegean-ai/engine-anomaly-detection-dataset", split="test")
        # Fault labels are parsed from each sample's webdataset key string.
        self.audio_labels = [parse_toycar_label(k) for k in an["__key__"]] + [parse_toycar_label(k) for k in aa["__key__"]]
        self.audio_data = concatenate_datasets([an, aa])
        print(f" Audio: {len(self.audio_data)}, dist: {Counter(self.audio_labels)}")
        print(f"[Real] Loading visual...")
        vt = load_dataset("TheoM55/mvtec_anomaly_detection", split="train")
        ve = load_dataset("TheoM55/mvtec_anomaly_detection", split="test")
        mm = build_mvtec_map(ve.features["label"].names)
        # MVTec train images are treated as defect-free (class 0);
        # test labels are remapped through build_mvtec_map.
        self.visual_labels = [0]*len(vt) + [mm.get(l,4) for l in ve["label"]]
        self.visual_data = concatenate_datasets([vt, ve])
        print(f" Visual: {len(self.visual_data)}, dist: {Counter(self.visual_labels)}")
        # Bucket sample indices by fault class for each modality
        # (abc = audio-by-class, vbc = visual-by-class).
        abc = {i:[] for i in range(5)}
        for i,l in enumerate(self.audio_labels): abc[l].append(i)
        vbc = {i:[] for i in range(5)}
        for i,l in enumerate(self.visual_labels): vbc[l].append(i)
        rng, pairs = random.Random(seed), []
        for c in range(5):
            a,v = abc[c], vbc[c]
            if not a or not v: print(f" WARN: class {c} missing data"); continue
            # Pair up to the larger modality's count, cycling the smaller
            # modality via modulo so no sample of either modality is dropped.
            for i in range(max(len(a),len(v))): pairs.append({"audio_idx":a[i%len(a)],"visual_idx":v[i%len(v)],"fault_label":c})
        # Seeded shuffle keeps the train/val split reproducible across runs.
        # NOTE(review): because samples are cycled above, the same underlying
        # audio clip or image can appear in both train and val pairs —
        # potential leakage; confirm this is acceptable for evaluation.
        rng.shuffle(pairs)
        si = int(len(pairs)*(1-val_ratio))
        self.samples = pairs[:si] if split=="train" else pairs[si:]
        lc = Counter(s["fault_label"] for s in self.samples)
        print(f"\n[Real] {split}: {len(self.samples)} pairs")
        for l in range(5): print(f" {FAULT_CLASSES[l]}: {lc.get(l,0)}")

    def __len__(self): return len(self.samples)

    def __getitem__(self, idx):
        """Return one pair: {"pixel_values", "audio_values", "labels"} tensors."""
        s = self.samples[idx]; fl = s["fault_label"]
        ad = self.audio_data[s["audio_idx"]]["wav"]
        # HF audio columns may arrive as decoded dicts ({"array",
        # "sampling_rate"}) or as raw arrays; fall back to the configured rate.
        if isinstance(ad,dict) and "array" in ad: arr,sr = np.array(ad["array"],dtype=np.float32), ad["sampling_rate"]
        else: arr,sr = np.array(ad,dtype=np.float32), self.config.sample_rate
        if self.ast_feature_extractor:
            # AST path: the feature extractor expects 16 kHz mono input.
            if sr!=16000: arr = T.Resample(sr,16000)(torch.tensor(arr).unsqueeze(0)).squeeze(0).numpy()
            av = self.ast_feature_extractor(arr, sampling_rate=16000, return_tensors="pt")["input_values"].squeeze(0)
        else:
            # Fallback path: local log-mel pipeline (+ masking when training).
            lm = self.audio_preprocessor.process(torch.tensor(arr,dtype=torch.float32), sr)
            if self.augment: lm = self.audio_preprocessor.augment(lm)
            av = lm.squeeze(0)
        img = self.visual_data[s["visual_idx"]]["image"]
        if not isinstance(img, Image.Image): img = Image.fromarray(np.array(img))
        if img.mode!="RGB": img = img.convert("RGB")
        if self.vit_processor: pv = self.vit_processor(images=img, return_tensors="pt")["pixel_values"].squeeze(0)
        else:
            # Manual ViT preprocessing: 224x224 resize, ImageNet mean/std.
            # NOTE(review): subtracting/dividing by Python lists promotes the
            # array to float64, so pv is a float64 tensor here — confirm the
            # downstream model accepts that (or casts it).
            a = np.array(img.resize((224,224))).astype(np.float32)/255.0
            pv = torch.tensor((a-[.485,.456,.406])/[.229,.224,.225]).permute(2,0,1)
        return {"pixel_values": pv, "audio_values": av, "labels": torch.tensor(fl, dtype=torch.long)}
|
|
def multimodal_collate_fn(batch):
    """Collate per-sample dicts into batched tensors.

    Images and labels are stacked directly; audio features are right-padded
    with zeros along the last (time) axis to the longest clip in the batch.
    """
    pixel_values = torch.stack([item["pixel_values"] for item in batch])
    labels = torch.stack([item["labels"] for item in batch])
    audio = [item["audio_values"] for item in batch]
    target_len = max(a.shape[-1] for a in audio)
    padded = []
    for a in audio:
        deficit = target_len - a.shape[-1]
        padded.append(torch.nn.functional.pad(a, (0, deficit)) if deficit > 0 else a)
    return {
        "pixel_values": pixel_values,
        "labels": labels,
        "audio_values": torch.stack(padded),
    }
|
|