"""
Multimodal PC Fault Detection - Flexible Dataset Module
========================================================
Supports multiple data source modes:
  Audio:  "esc50" (default) | "custom" (any HF audio dataset)
  Visual: "synthetic" (default) | "custom" (any HF image dataset)

See examples at bottom of file or in README.
"""

import random

import numpy as np
import torch
from torch.utils.data import Dataset
from PIL import Image, ImageDraw
import torchaudio.transforms as T
from datasets import load_dataset

from config import FAULT_CLASSES, ESC50_TO_FAULT, VISUAL_SYNTHESIS, DataConfig, ModelConfig


class SyntheticVisualGenerator:
    """Renders deterministic placeholder screenshots for each fault class."""

    def __init__(self, image_size=224):
        self.image_size = image_size
        self.templates = VISUAL_SYNTHESIS

    def generate(self, fault_class_name, variation_seed=0):
        rng = random.Random(variation_seed)
        template = self.templates[fault_class_name]
        # Jitter the template's dominant color so repeated samples differ.
        base_r, base_g, base_b = template["color_dominant"]
        r, g, b = [max(0, min(255, c + rng.randint(-30, 30))) for c in (base_r, base_g, base_b)]
        img = Image.new("RGB", (self.image_size, self.image_size), (r, g, b))
        draw = ImageDraw.Draw(img)
        y = rng.randint(10, 40)
        for text in template["text_overlay"]:
            # White text on dark backgrounds, black text on light ones.
            text_color = (255, 255, 255) if (r + g + b) / 3 < 128 else (0, 0, 0)
            draw.text((rng.randint(5, max(10, self.image_size // 4)), y), text, fill=text_color)
            y += rng.randint(25, 45)
        # Add seeded pixel noise so the images are not perfectly flat.
        arr = np.array(img).astype(np.int16)
        noise = np.random.RandomState(variation_seed).randint(-15, 16, arr.shape, dtype=np.int16)
        return Image.fromarray(np.clip(arr + noise, 0, 255).astype(np.uint8))
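
# A minimal usage sketch (not executed on import); `FAULT_CLASSES[0]` is just
# an example key, and any class name present in VISUAL_SYNTHESIS works:
#
#     gen = SyntheticVisualGenerator(image_size=224)
#     img = gen.generate(FAULT_CLASSES[0], variation_seed=42)  # 224x224 RGB PIL image
#     img.save("preview.png")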


class AudioPreprocessor:
    """Converts raw waveforms to fixed-length log-mel spectrograms."""

    def __init__(self, config):
        self.config = config
        self.target_length = int(config.sample_rate * config.audio_duration)
        self.mel = T.MelSpectrogram(
            sample_rate=config.sample_rate, n_fft=config.n_fft,
            hop_length=config.hop_length, n_mels=config.n_mels,
            f_min=config.fmin, f_max=config.fmax)
        self.to_db = T.AmplitudeToDB(top_db=80)

    def process(self, waveform, sr):
        # Downmix multi-channel audio to mono and ensure a (1, samples) shape.
        if waveform.dim() > 1 and waveform.shape[0] > 1:
            waveform = waveform.mean(0, keepdim=True)
        elif waveform.dim() == 1:
            waveform = waveform.unsqueeze(0)
        if sr != self.config.sample_rate:
            waveform = T.Resample(sr, self.config.sample_rate)(waveform)
        # Pad or truncate to a fixed duration so spectrograms batch cleanly.
        if waveform.shape[-1] < self.target_length:
            waveform = torch.nn.functional.pad(waveform, (0, self.target_length - waveform.shape[-1]))
        else:
            waveform = waveform[:, :self.target_length]
        return self.to_db(self.mel(waveform))

    def augment(self, x, training=True):
        # SpecAugment-style frequency and time masking, training only.
        if not training:
            return x
        return T.TimeMasking(self.config.time_mask_max)(T.FrequencyMasking(self.config.freq_mask_max)(x))
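
# Shape sketch (assuming the DataConfig defaults used elsewhere in the repo):
#
#     cfg = DataConfig()
#     pre = AudioPreprocessor(cfg)
#     wav = torch.randn(int(cfg.sample_rate * cfg.audio_duration))  # mono waveform
#     mel = pre.process(wav, sr=cfg.sample_rate)   # -> (1, cfg.n_mels, frames)
#     mel = pre.augment(mel, training=True)        # same shape, with random masks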


class PCFaultDataset(Dataset):
    """
    Flexible multimodal dataset. Works with any HF audio + image dataset.
    
    AUDIO: set audio_source="custom" and provide:
      custom_audio_dataset, custom_audio_column, custom_audio_label_column, custom_audio_label_map
    
    VISUAL: set visual_source="custom" and provide:
      custom_visual_dataset, custom_visual_column, custom_visual_label_column, custom_visual_label_map
    """
    def __init__(self, config: DataConfig, model_config: ModelConfig, split="train",
                 vit_processor=None, ast_feature_extractor=None, augment=True,
                 audio_source="esc50", custom_audio_dataset=None, custom_audio_config=None, custom_audio_split=None,
                 custom_audio_column="audio", custom_audio_label_column="label", custom_audio_label_map=None,
                 visual_source="synthetic", custom_visual_dataset=None, custom_visual_config=None, custom_visual_split=None,
                 custom_visual_column="image", custom_visual_label_column="label", custom_visual_label_map=None):
        self.config = config
        self.split = split
        self.augment = augment and (split == "train")  # augment training data only
        self.vit_processor = vit_processor
        self.ast_feature_extractor = ast_feature_extractor
        self.audio_source = audio_source
        self.visual_source = visual_source
        self.audio_preprocessor = AudioPreprocessor(config)
        self.visual_gen = SyntheticVisualGenerator(config.image_size)
        self._audio_column = custom_audio_column
        
        # --- Load audio -------------------------------------------------
        if audio_source == "esc50":
            # ESC-50 ships as a single "train" split; train/val are separated
            # here by the official fold assignments.
            self.audio_data = load_dataset(config.esc50_dataset, split="train")
            self.samples = []
            for idx, row in enumerate(self.audio_data):
                cat = row["category"]
                if cat in ESC50_TO_FAULT:
                    fault_label, fold = ESC50_TO_FAULT[cat], row["fold"]
                    if (split == "train" and fold != config.val_fold) or (split == "val" and fold == config.val_fold):
                        self.samples.append({"audio_idx": idx, "fault_label": fault_label})
        elif audio_source == "custom":
            if not custom_audio_dataset:
                raise ValueError("custom_audio_dataset is required when audio_source='custom'")
            if not custom_audio_label_map:
                raise ValueError("custom_audio_label_map is required: {original_label: 0-4}")
            kw = {"name": custom_audio_config} if custom_audio_config else {}
            self.audio_data = load_dataset(custom_audio_dataset, split=custom_audio_split or split, **kw)
            self.samples = []
            for idx, row in enumerate(self.audio_data):
                lbl = row[custom_audio_label_column]
                # Accept both raw and stringified keys in the label map;
                # unmapped labels are skipped.
                mapped = custom_audio_label_map.get(lbl, custom_audio_label_map.get(str(lbl)))
                if mapped is not None:
                    self.samples.append({"audio_idx": idx, "fault_label": mapped})
        
        # --- Load visual ------------------------------------------------
        self.visual_data = None
        self._visual_by_class = {i: [] for i in range(len(FAULT_CLASSES))}
        if visual_source == "custom" and custom_visual_dataset:
            kw = {"name": custom_visual_config} if custom_visual_config else {}
            self.visual_data = load_dataset(custom_visual_dataset, split=custom_visual_split or split, **kw)
            self._visual_column = custom_visual_column
            # Index image rows by mapped fault class so __getitem__ can pick a
            # class-matched image; unmapped labels are skipped.
            label_map = custom_visual_label_map or {}
            for idx, row in enumerate(self.visual_data):
                lbl = row.get(custom_visual_label_column)
                mapped = label_map.get(lbl, label_map.get(str(lbl)))
                if mapped is not None:
                    self._visual_by_class[mapped].append(idx)
            per_class = {k: len(v) for k, v in self._visual_by_class.items()}
            print(f"[Visual] Custom: {custom_visual_dataset}, per-class: {per_class}")
        
        # --- Per-class statistics ----------------------------------------
        label_counts = {}
        for s in self.samples:
            label_counts[s["fault_label"]] = label_counts.get(s["fault_label"], 0) + 1
        print(f"[Dataset] {split} | audio={audio_source} | visual={visual_source} | total={len(self.samples)}")
        for label, count in sorted(label_counts.items()):
            print(f"  {FAULT_CLASSES[label]}: {count}")
    
    def __len__(self): return len(self.samples)
    
    def __getitem__(self, idx):
        s = self.samples[idx]
        fault_label, fault_name = s["fault_label"], FAULT_CLASSES[s["fault_label"]]

        # --- Audio --------------------------------------------------------
        row = self.audio_data[s["audio_idx"]]
        ad = row["audio"] if self.audio_source == "esc50" else row[self._audio_column]
        if isinstance(ad, dict) and "array" in ad:
            arr, sr = np.array(ad["array"], dtype=np.float32), ad["sampling_rate"]
        else:
            arr, sr = np.array(ad, dtype=np.float32), self.config.sample_rate

        if self.ast_feature_extractor:
            # AST expects 16 kHz input; resample first if necessary.
            if sr != 16000:
                arr = T.Resample(sr, 16000)(torch.tensor(arr).unsqueeze(0)).squeeze(0).numpy()
            audio_values = self.ast_feature_extractor(arr, sampling_rate=16000, return_tensors="pt")["input_values"].squeeze(0)
        else:
            log_mel = self.audio_preprocessor.process(torch.tensor(arr, dtype=torch.float32), sr)
            if self.augment:
                log_mel = self.audio_preprocessor.augment(log_mel)
            audio_values = log_mel.squeeze(0)

        # --- Visual ---------------------------------------------------------
        if self.visual_data is not None and self._visual_by_class.get(fault_label):
            # Cycle deterministically through the class-matched images.
            class_indices = self._visual_by_class[fault_label]
            img = self.visual_data[class_indices[idx % len(class_indices)]][self._visual_column]
            if not isinstance(img, Image.Image):
                img = Image.fromarray(np.array(img))
        else:
            img = self.visual_gen.generate(fault_name, variation_seed=idx * 7 + fault_label)
        if img.mode != "RGB":
            img = img.convert("RGB")

        if self.vit_processor:
            pv = self.vit_processor(images=img, return_tensors="pt")["pixel_values"].squeeze(0)
        else:
            # Fallback: manual ImageNet normalization. Keep everything float32;
            # subtracting plain Python lists would upcast the result to float64.
            a = np.array(img.resize((224, 224))).astype(np.float32) / 255.0
            mean = np.array([0.485, 0.456, 0.406], dtype=np.float32)
            std = np.array([0.229, 0.224, 0.225], dtype=np.float32)
            pv = torch.from_numpy((a - mean) / std).permute(2, 0, 1)

        return {"pixel_values": pv, "audio_values": audio_values, "labels": torch.tensor(fault_label, dtype=torch.long)}


def multimodal_collate_fn(batch):
    """Stacks images and labels; right-pads audio features to the batch max length."""
    pixel_values = torch.stack([b["pixel_values"] for b in batch])
    labels = torch.stack([b["labels"] for b in batch])
    audio_list = [b["audio_values"] for b in batch]
    # Pad along the final (time) axis so variable-length spectrograms stack.
    max_len = max(a.shape[-1] for a in audio_list)
    audio_values = torch.stack([
        torch.nn.functional.pad(a, (0, max_len - a.shape[-1])) if a.shape[-1] < max_len else a
        for a in audio_list])
    return {"pixel_values": pixel_values, "audio_values": audio_values, "labels": labels}