Ellaft committed (verified)
Commit c29559f · 1 Parent(s): ddf61da

Add dataset_v2.py: adapter for build_dataset.py output, drop-in replacement for dataset_real.py

Files changed (1)
  1. src/dataset_v2.py +433 -0
src/dataset_v2.py ADDED
@@ -0,0 +1,433 @@
"""
Dataset Loader v2 — Loads data built by build_dataset.py
=========================================================
Drop-in replacement for dataset_real.py. Loads from either:
1. Local manifest (dataset_build/dataset_manifest.json) — from build_dataset.py
2. HuggingFace Hub dataset (Ellaft/pc-fault-real-dataset) — if uploaded

Data sources: YouTube-scraped audio/frames, HF cooling-fan recordings,
synthetic BIOS beep codes, synthetic HDD clicks, synthetic BSOD/POST/thermal images.

Usage — just change one import in train_v2.py:
    from dataset_v2 import BuiltDataset as PCFaultDataset, multimodal_collate_fn

Or run train_v2.py with the --dataset flag:
    python train_v2.py --dataset local --dataset_dir ./dataset_build
    python train_v2.py --dataset hub --hub_dataset Ellaft/pc-fault-real-dataset
"""
import os, json, random, glob
import numpy as np
import torch
import torch.nn.functional as F
from torch.utils.data import Dataset
from PIL import Image
from pathlib import Path
from collections import Counter
from typing import Optional

try:
    import torchaudio.transforms as T
    HAS_TORCHAUDIO = True
except ImportError:
    HAS_TORCHAUDIO = False

try:
    import soundfile as sf
    HAS_SOUNDFILE = True
except ImportError:
    HAS_SOUNDFILE = False

try:
    import librosa
    HAS_LIBROSA = True
except ImportError:
    HAS_LIBROSA = False

try:
    from config import FAULT_CLASSES, DataConfig, ModelConfig
except ImportError:
    # Standalone mode — define fault classes inline
    FAULT_CLASSES = [
        "normal_operation", "boot_failure", "overheating_fan",
        "storage_failure", "system_crash",
    ]
    DataConfig = None
    ModelConfig = None

# ============================================================================
# Audio loading helpers
# ============================================================================

def load_audio_file(path, target_sr=16000):
    """Load a WAV file and return (numpy_array, sample_rate)."""
    if HAS_SOUNDFILE:
        arr, sr = sf.read(path, dtype="float32")
        if arr.ndim > 1:
            arr = arr.mean(axis=1)  # mono
        return arr, sr
    elif HAS_LIBROSA:
        arr, sr = librosa.load(path, sr=target_sr, mono=True)
        return arr, sr
    elif HAS_TORCHAUDIO:
        import torchaudio
        waveform, sr = torchaudio.load(path)
        if waveform.shape[0] > 1:
            waveform = waveform.mean(dim=0, keepdim=True)
        return waveform.squeeze(0).numpy(), sr
    else:
        raise ImportError("Need soundfile, librosa, or torchaudio to load audio. "
                          "Install: pip install soundfile")


def resample_audio(arr, orig_sr, target_sr=16000):
    """Resample an audio array to the target sample rate."""
    if orig_sr == target_sr:
        return arr
    if HAS_TORCHAUDIO:
        resampler = T.Resample(orig_sr, target_sr)
        tensor = torch.tensor(arr, dtype=torch.float32).unsqueeze(0)
        return resampler(tensor).squeeze(0).numpy()
    elif HAS_LIBROSA:
        return librosa.resample(arr, orig_sr=orig_sr, target_sr=target_sr)
    else:
        # Simple linear interpolation fallback
        ratio = target_sr / orig_sr
        new_len = int(len(arr) * ratio)
        indices = np.linspace(0, len(arr) - 1, new_len)
        return np.interp(indices, np.arange(len(arr)), arr).astype(np.float32)

# ============================================================================
# Main Dataset Class
# ============================================================================

class BuiltDataset(Dataset):
    """
    Loads the multimodal PC fault dataset from build_dataset.py output.

    Matches the exact interface of RealPCFaultDataset so train_v2.py works
    without any changes — just swap the import.

    Supports two modes:
    - "local": load from manifest JSON + local files (default)
    - "hub":   load from a HuggingFace Hub dataset
    """

    def __init__(self, config, model_config, split="train",
                 vit_processor=None, ast_feature_extractor=None,
                 augment=True, val_ratio=0.15, test_ratio=0.15, seed=42,
                 # New parameters for the v2 dataset
                 source="local",  # "local" or "hub"
                 dataset_dir="./dataset_build",
                 hub_dataset="Ellaft/pc-fault-real-dataset"):
        """
        Args:
            config: DataConfig instance
            model_config: ModelConfig instance (unused, kept for compatibility)
            split: "train", "val", or "test"
            vit_processor: ViT image processor
            ast_feature_extractor: AST feature extractor
            augment: whether to apply data augmentation (train only)
            val_ratio: validation split ratio (local mode)
            test_ratio: test split ratio (local mode)
            seed: random seed for reproducibility
            source: "local" (manifest files) or "hub" (HF dataset)
            dataset_dir: path to build_dataset.py output (local mode)
            hub_dataset: HuggingFace dataset ID (hub mode)
        """
        self.config = config
        self.split = split
        self.augment = augment and (split == "train")
        self.vit_processor = vit_processor
        self.ast_feature_extractor = ast_feature_extractor
        self.target_sr = 16000  # AST expects 16 kHz
        self.audio_duration = config.audio_duration  # seconds
        self.target_audio_len = int(self.target_sr * self.audio_duration)

        if source == "hub":
            self._load_from_hub(hub_dataset, split, seed)
        else:
            self._load_from_local(dataset_dir, split, val_ratio, test_ratio, seed)

        # Print statistics
        lc = Counter(s["fault_label"] for s in self.samples)
        n_has_audio = sum(1 for s in self.samples
                          if s.get("audio_path") or s.get("audio_data") is not None)
        n_has_image = sum(1 for s in self.samples
                          if s.get("image_path") or s.get("image_data") is not None)
        print(f"\n[BuiltDataset] {split}: {len(self.samples)} samples "
              f"(audio: {n_has_audio}, images: {n_has_image})")
        for label_id in range(len(FAULT_CLASSES)):
            print(f"  {FAULT_CLASSES[label_id]}: {lc.get(label_id, 0)}")

    def _load_from_local(self, dataset_dir, split, val_ratio, test_ratio, seed):
        """Load from the build_dataset.py manifest."""
        dataset_dir = Path(dataset_dir)
        manifest_path = dataset_dir / "dataset_manifest.json"

        if not manifest_path.exists():
            raise FileNotFoundError(
                f"Dataset manifest not found at {manifest_path}\n"
                f"Run build_dataset.py first:\n"
                f"  cd data && python build_dataset.py --max_per_class 300")

        print(f"[BuiltDataset] Loading from {manifest_path}")
        with open(manifest_path) as f:
            manifest = json.load(f)

        all_samples = manifest["samples"]
        print(f"  Total samples in manifest: {len(all_samples)}")

        # Convert manifest format to our internal format
        samples = []
        for s in all_samples:
            samples.append({
                "fault_label": s["fault_class"],
                "audio_path": s.get("audio_path"),
                "image_path": s.get("image_path"),
            })

        # Stratified split
        rng = random.Random(seed)
        by_class = {i: [] for i in range(len(FAULT_CLASSES))}
        for s in samples:
            by_class[s["fault_label"]].append(s)

        train_samples, val_samples, test_samples = [], [], []
        for cls_id, cls_samples in by_class.items():
            rng.shuffle(cls_samples)
            n = len(cls_samples)
            n_test = max(1, int(n * test_ratio))
            n_val = max(1, int(n * val_ratio))
            # The remainder (n - n_val - n_test) goes to train

            test_samples.extend(cls_samples[:n_test])
            val_samples.extend(cls_samples[n_test:n_test + n_val])
            train_samples.extend(cls_samples[n_test + n_val:])

        if split == "train":
            self.samples = train_samples
        elif split in ("val", "validation"):
            self.samples = val_samples
        else:
            self.samples = test_samples

        rng.shuffle(self.samples)

    def _load_from_hub(self, hub_dataset, split, seed):
        """Load from a HuggingFace Hub dataset."""
        from datasets import load_dataset

        # Map our split names to Hub split names
        hub_split = {"val": "validation", "validation": "validation",
                     "train": "train", "test": "test"}.get(split, split)

        print(f"[BuiltDataset] Loading from Hub: {hub_dataset} (split={hub_split})")
        ds = load_dataset(hub_dataset, split=hub_split)
        print(f"  Loaded {len(ds)} samples")

        self.hub_data = ds
        self.samples = []
        for i in range(len(ds)):
            self.samples.append({
                "fault_label": ds[i]["fault_class"],
                "hub_idx": i,
                # Note: indexing ds[i] decodes audio/image here, at init time
                "audio_data": ds[i].get("audio"),
                "image_data": ds[i].get("image"),
            })

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        s = self.samples[idx]
        fault_label = s["fault_label"]

        # ---- Load audio ----
        audio_values = self._load_audio(s)

        # ---- Load image ----
        pixel_values = self._load_image(s)

        return {
            "pixel_values": pixel_values,
            "audio_values": audio_values,
            "labels": torch.tensor(fault_label, dtype=torch.long),
        }

    def _load_audio(self, sample):
        """Load and process audio into AST-compatible format."""
        arr = None
        sr = self.target_sr

        # Try Hub data first
        if "audio_data" in sample and sample["audio_data"] is not None:
            audio_data = sample["audio_data"]
            if isinstance(audio_data, dict):
                arr = np.array(audio_data["array"], dtype=np.float32)
                sr = audio_data.get("sampling_rate", self.target_sr)
            elif isinstance(audio_data, np.ndarray):
                arr = audio_data.astype(np.float32)

        # Try a local file
        elif sample.get("audio_path") and os.path.exists(sample["audio_path"]):
            try:
                arr, sr = load_audio_file(sample["audio_path"], self.target_sr)
            except Exception as e:
                print(f"  ⚠ Failed to load audio {sample['audio_path']}: {e}")
                arr = None

        # Fallback: generate silence (the model still gets the image)
        if arr is None:
            arr = np.zeros(self.target_audio_len, dtype=np.float32)
            sr = self.target_sr

        # Ensure float32
        arr = arr.astype(np.float32)

        # Resample to 16 kHz for AST
        if sr != self.target_sr:
            arr = resample_audio(arr, sr, self.target_sr)

        # Pad/trim to the target duration
        if len(arr) < self.target_audio_len:
            arr = np.pad(arr, (0, self.target_audio_len - len(arr)))
        elif len(arr) > self.target_audio_len:
            # Random crop during training, center crop during eval
            if self.augment:
                start = random.randint(0, len(arr) - self.target_audio_len)
            else:
                start = (len(arr) - self.target_audio_len) // 2
            arr = arr[start:start + self.target_audio_len]

        # Data augmentation (training only)
        if self.augment:
            arr = self._augment_audio(arr)

        # Process with the AST feature extractor
        if self.ast_feature_extractor:
            inputs = self.ast_feature_extractor(
                arr, sampling_rate=self.target_sr,
                return_tensors="pt")
            audio_values = inputs["input_values"].squeeze(0)
        else:
            # Fallback: raw waveform tensor
            audio_values = torch.tensor(arr, dtype=torch.float32)

        return audio_values

    def _load_image(self, sample):
        """Load and process an image into ViT-compatible format."""
        img = None

        # Try Hub data first
        if "image_data" in sample and sample["image_data"] is not None:
            img = sample["image_data"]
            if not isinstance(img, Image.Image):
                try:
                    img = Image.fromarray(np.array(img))
                except Exception:
                    img = None

        # Try a local file
        elif sample.get("image_path") and os.path.exists(sample["image_path"]):
            try:
                img = Image.open(sample["image_path"])
            except Exception as e:
                print(f"  ⚠ Failed to load image {sample['image_path']}: {e}")
                img = None

        # Fallback: black image
        if img is None:
            img = Image.new("RGB", (224, 224), color=(0, 0, 0))

        # Ensure RGB
        if img.mode != "RGB":
            img = img.convert("RGB")

        # Data augmentation (training only)
        if self.augment:
            img = self._augment_image(img)

        # Process with the ViT processor
        if self.vit_processor:
            pixel_values = self.vit_processor(
                images=img, return_tensors="pt")["pixel_values"].squeeze(0)
        else:
            # Manual ImageNet normalization fallback
            arr = np.array(img.resize((224, 224))).astype(np.float32) / 255.0
            arr = (arr - [0.485, 0.456, 0.406]) / [0.229, 0.224, 0.225]
            pixel_values = torch.tensor(arr, dtype=torch.float32).permute(2, 0, 1)

        return pixel_values

    def _augment_audio(self, arr):
        """Audio augmentation: noise injection, time shift, gain variation."""
        # Random gain
        if random.random() < 0.5:
            gain = random.uniform(0.7, 1.3)
            arr = arr * gain

        # Add background noise
        if random.random() < 0.3:
            noise_level = random.uniform(0.001, 0.01)
            arr = arr + np.random.randn(len(arr)).astype(np.float32) * noise_level

        # Time shift
        if random.random() < 0.3:
            shift = random.randint(-int(0.1 * len(arr)), int(0.1 * len(arr)))
            arr = np.roll(arr, shift)

        return np.clip(arr, -1, 1).astype(np.float32)

    def _augment_image(self, img):
        """Image augmentation: horizontal flip, brightness/contrast jitter."""
        from PIL import ImageEnhance

        # Random horizontal flip
        if random.random() < 0.5:
            img = img.transpose(Image.FLIP_LEFT_RIGHT)

        # Random brightness variation
        if random.random() < 0.3:
            factor = random.uniform(0.8, 1.2)
            img = ImageEnhance.Brightness(img).enhance(factor)

        # Random contrast variation
        if random.random() < 0.3:
            factor = random.uniform(0.8, 1.2)
            img = ImageEnhance.Contrast(img).enhance(factor)

        return img

# ============================================================================
# Collate function — same interface as dataset_real.py
# ============================================================================

def multimodal_collate_fn(batch):
    """
    Collate function that handles variable-length audio.
    Pads audio to the max length in the batch.
    """
    pixel_values = torch.stack([b["pixel_values"] for b in batch])
    labels = torch.stack([b["labels"] for b in batch])

    audio_list = [b["audio_values"] for b in batch]
    max_len = max(a.shape[-1] for a in audio_list)

    padded_audio = []
    for a in audio_list:
        if a.shape[-1] < max_len:
            pad_size = max_len - a.shape[-1]
            a = F.pad(a, (0, pad_size))
        padded_audio.append(a)

    audio_values = torch.stack(padded_audio)

    return {
        "pixel_values": pixel_values,
        "audio_values": audio_values,
        "labels": labels,
    }
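
For reference, a minimal sketch of how this loader is meant to be wired into training, per the docstring's "drop-in replacement" claim. Assumptions not pinned by this file: the processors come from the `transformers` package (ViTImageProcessor / ASTFeatureExtractor), the checkpoint names below are illustrative, and a stand-in config object is sufficient because BuiltDataset only reads `config.audio_duration`.

# Usage sketch (assumes transformers is installed; checkpoint names are
# illustrative, not mandated by this repo).
from types import SimpleNamespace

from torch.utils.data import DataLoader
from transformers import ASTFeatureExtractor, ViTImageProcessor

from dataset_v2 import BuiltDataset, multimodal_collate_fn

# BuiltDataset only reads config.audio_duration, so a stand-in works
# when the real DataConfig is unavailable.
config = SimpleNamespace(audio_duration=5.0)

train_ds = BuiltDataset(
    config, model_config=None, split="train",
    vit_processor=ViTImageProcessor.from_pretrained(
        "google/vit-base-patch16-224-in21k"),
    ast_feature_extractor=ASTFeatureExtractor.from_pretrained(
        "MIT/ast-finetuned-audioset-10-10-0.4593"),
    source="local", dataset_dir="./dataset_build",
)

loader = DataLoader(train_ds, batch_size=8, shuffle=True,
                    collate_fn=multimodal_collate_fn)

batch = next(iter(loader))
# batch["pixel_values"]: (B, 3, 224, 224) images for ViT
# batch["audio_values"]: (B, ...) AST features, padded to the batch max length
# batch["labels"]:       (B,) fault-class indices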