"""
Real Dataset Loader — replaces dataset.py
Audio: aegean-ai/engine-anomaly-detection-dataset (ToyADMOS/DCASE)
Visual: TheoM55/mvtec_anomaly_detection (MVTec AD)

Usage: change one import in train.py:
  from dataset_real import RealPCFaultDataset as PCFaultDataset, multimodal_collate_fn
"""
import re, random, numpy as np, torch
from torch.utils.data import Dataset
from PIL import Image
from typing import Dict
import torchaudio.transforms as T
from datasets import load_dataset, concatenate_datasets
from collections import Counter
from config import FAULT_CLASSES, DataConfig, ModelConfig

class AudioPreprocessor:
    """Waveform -> log-mel spectrogram at the configured sample rate."""
    def __init__(self, config):
        self.config = config
        self.target_len = int(config.sample_rate * config.audio_duration)
        self.mel = T.MelSpectrogram(sample_rate=config.sample_rate, n_fft=config.n_fft,
                                    hop_length=config.hop_length, n_mels=config.n_mels,
                                    f_min=config.fmin, f_max=config.fmax)
        self.db = T.AmplitudeToDB(top_db=80)
    def process(self, w, sr):
        # Downmix to mono and ensure a (1, samples) shape.
        if w.dim()>1 and w.shape[0]>1: w = w.mean(0, keepdim=True)
        elif w.dim()==1: w = w.unsqueeze(0)
        if sr != self.config.sample_rate: w = T.Resample(sr, self.config.sample_rate)(w)
        # Pad or trim to a fixed length so every spectrogram has the same width.
        if w.shape[-1] != self.target_len:
            w = torch.nn.functional.pad(w, (0, max(0, self.target_len - w.shape[-1])))[:, :self.target_len]
        return self.db(self.mel(w))
    def augment(self, x):
        # SpecAugment-style masking: frequency mask, then time mask.
        return T.TimeMasking(self.config.time_mask_max)(T.FrequencyMasking(self.config.freq_mask_max)(x))

def parse_toycar_label(key):
    """Map a ToyADMOS sample key to a fault class: 'normal' -> 0, anomaly IDs
    ab01..ab10 grouped into classes 1-4; keys with no match default to normal."""
    m = re.search(r'_(ab\d{2})_|_(normal)_', key)
    c = (m.group(1) or m.group(2)) if m else "normal"
    return {"normal":0,"ab01":1,"ab02":1,"ab03":2,"ab04":2,"ab05":3,"ab06":3,"ab07":4,"ab08":4,"ab09":4,"ab10":4}.get(c,0)
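# A quick sanity check of the key parser on hypothetical ToyADMOS-style keys
# (the function is reproduced below so the snippet runs standalone; the example
# key strings are illustrative, not taken from the dataset):
#
# ```python
# import re
#
# def parse_toycar_label(key):
#     # Reproduced from the module for a standalone demo.
#     m = re.search(r'_(ab\d{2})_|_(normal)_', key)
#     c = (m.group(1) or m.group(2)) if m else "normal"
#     return {"normal":0,"ab01":1,"ab02":1,"ab03":2,"ab04":2,"ab05":3,
#             "ab06":3,"ab07":4,"ab08":4,"ab09":4,"ab10":4}.get(c, 0)
#
# print(parse_toycar_label("ToyCar_case1_normal_0001"))  # 0
# print(parse_toycar_label("ToyCar_case1_ab03_0102"))    # 2
# print(parse_toycar_label("no_marker_here"))            # 0 (falls back to normal)
# ```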

def build_mvtec_map(names):
    """Map MVTec AD defect-label names (by index) to coarse fault classes:
    0=good, 1=missing/cut/manipulated parts, 2=contamination/deformation,
    3=cracks/breaks/surface damage, 4=everything else."""
    m = {}
    for i, n in enumerate(names):
        if n=="good": m[i]=0
        elif n in ["missing_cable","missing_wire","cut","cut_lead","cut_inner_insulation","cut_outer_insulation","manipulated_front","poke","poke_insulation","cable_swap","combined"]: m[i]=1
        elif n in ["contamination","liquid","oil","color","metal_contamination","bent","bent_lead","glue","glue_strip"]: m[i]=2
        elif n in ["crack","scratch","scratch_head","scratch_neck","broken_large","broken_small","split_teeth","broken_teeth","rough","thread_side","thread_top","squeeze","fabric_border","fabric_interior"]: m[i]=3
        else: m[i]=4
    return m
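# The grouping is easiest to see by running the mapper on a hypothetical name
# list (real lists come from the dataset's `features["label"].names`; the name
# "hole" below is made up to show the catch-all class):
#
# ```python
# # Indices follow list order; unknown names fall into class 4.
# print(build_mvtec_map(["good", "cut_lead", "oil", "crack", "hole"]))
# # {0: 0, 1: 1, 2: 2, 3: 3, 4: 4}
# ```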

class RealPCFaultDataset(Dataset):
    def __init__(self, config, model_config, split="train", vit_processor=None, ast_feature_extractor=None, augment=True, val_ratio=0.2, seed=42):
        self.config, self.split, self.augment = config, split, augment and (split=="train")
        self.vit_processor, self.ast_feature_extractor = vit_processor, ast_feature_extractor
        self.audio_preprocessor = AudioPreprocessor(config)
        print(f"\n[Real] Loading audio...")
        an = load_dataset("aegean-ai/engine-anomaly-detection-dataset", split="train")
        aa = load_dataset("aegean-ai/engine-anomaly-detection-dataset", split="test")
        self.audio_labels = [parse_toycar_label(k) for k in an["__key__"]] + [parse_toycar_label(k) for k in aa["__key__"]]
        self.audio_data = concatenate_datasets([an, aa])
        print(f"  Audio: {len(self.audio_data)}, dist: {Counter(self.audio_labels)}")
        print(f"[Real] Loading visual...")
        vt = load_dataset("TheoM55/mvtec_anomaly_detection", split="train")
        ve = load_dataset("TheoM55/mvtec_anomaly_detection", split="test")
        mm = build_mvtec_map(ve.features["label"].names)
        self.visual_labels = [0]*len(vt) + [mm.get(l,4) for l in ve["label"]]
        self.visual_data = concatenate_datasets([vt, ve])
        print(f"  Visual: {len(self.visual_data)}, dist: {Counter(self.visual_labels)}")
        # Bucket sample indices by fault class for each modality.
        abc = {i:[] for i in range(5)}
        for i,l in enumerate(self.audio_labels): abc[l].append(i)
        vbc = {i:[] for i in range(5)}
        for i,l in enumerate(self.visual_labels): vbc[l].append(i)
        # Pair audio and visual samples class-by-class, cycling the shorter index
        # list so every sample from the larger modality is used at least once.
        rng, pairs = random.Random(seed), []
        for c in range(5):
            a,v = abc[c], vbc[c]
            if not a or not v: print(f"  WARN: class {c} missing data"); continue
            for i in range(max(len(a),len(v))): pairs.append({"audio_idx":a[i%len(a)],"visual_idx":v[i%len(v)],"fault_label":c})
        rng.shuffle(pairs)
        si = int(len(pairs)*(1-val_ratio))
        self.samples = pairs[:si] if split=="train" else pairs[si:]
        lc = Counter(s["fault_label"] for s in self.samples)
        print(f"\n[Real] {split}: {len(self.samples)} pairs")
        for l in range(5): print(f"  {FAULT_CLASSES[l]}: {lc.get(l,0)}")

    def __len__(self): return len(self.samples)

    def __getitem__(self, idx):
        s = self.samples[idx]; fl = s["fault_label"]
        ad = self.audio_data[s["audio_idx"]]["wav"]
        # HF audio columns decode to {"array": ..., "sampling_rate": ...};
        # otherwise assume the configured sample rate.
        if isinstance(ad,dict) and "array" in ad: arr,sr = np.array(ad["array"],dtype=np.float32), ad["sampling_rate"]
        else: arr,sr = np.array(ad,dtype=np.float32), self.config.sample_rate
        if self.ast_feature_extractor:
            if sr!=16000: arr = T.Resample(sr,16000)(torch.tensor(arr).unsqueeze(0)).squeeze(0).numpy()
            av = self.ast_feature_extractor(arr, sampling_rate=16000, return_tensors="pt")["input_values"].squeeze(0)
        else:
            lm = self.audio_preprocessor.process(torch.tensor(arr,dtype=torch.float32), sr)
            if self.augment: lm = self.audio_preprocessor.augment(lm)
            av = lm.squeeze(0)
        img = self.visual_data[s["visual_idx"]]["image"]
        if not isinstance(img, Image.Image): img = Image.fromarray(np.array(img))
        if img.mode!="RGB": img = img.convert("RGB")
        if self.vit_processor: pv = self.vit_processor(images=img, return_tensors="pt")["pixel_values"].squeeze(0)
        else:
            # Fallback ImageNet normalization. Broadcasting against Python lists
            # promotes the array to float64, so keep the constants float32 to
            # match the model's weights.
            a = np.array(img.resize((224,224))).astype(np.float32)/255.0
            a = (a - np.array([.485,.456,.406], dtype=np.float32)) / np.array([.229,.224,.225], dtype=np.float32)
            pv = torch.from_numpy(a).permute(2,0,1)
        return {"pixel_values": pv, "audio_values": av, "labels": torch.tensor(fl, dtype=torch.long)}

def multimodal_collate_fn(batch):
    # Images are fixed-size, but audio features can vary in length: right-pad
    # each clip to the longest one in the batch before stacking.
    pv = torch.stack([b["pixel_values"] for b in batch])
    labels = torch.stack([b["labels"] for b in batch])
    al = [b["audio_values"] for b in batch]; ml = max(a.shape[-1] for a in al)
    return {"pixel_values":pv,"labels":labels,"audio_values":torch.stack([torch.nn.functional.pad(a,(0,ml-a.shape[-1])) if a.shape[-1]<ml else a for a in al])}
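
# The collate function can be sanity-checked without downloading either dataset
# by feeding it synthetic tensors of mismatched audio length (in train.py it
# would be imported from this module; the shapes below are illustrative):
#
# ```python
# import torch
#
# # Two fake samples: same image shape, audio features of 100 vs. 80 frames.
# batch = [
#     {"pixel_values": torch.zeros(3, 224, 224), "audio_values": torch.zeros(64, 100),
#      "labels": torch.tensor(0)},
#     {"pixel_values": torch.zeros(3, 224, 224), "audio_values": torch.zeros(64, 80),
#      "labels": torch.tensor(3)},
# ]
# out = multimodal_collate_fn(batch)
# print(out["pixel_values"].shape, out["audio_values"].shape, out["labels"].shape)
# # torch.Size([2, 3, 224, 224]) torch.Size([2, 64, 100]) torch.Size([2])
# ```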