"""
Dataset Loader v2 — Loads data built by build_dataset.py
==========================================================
Drop-in replacement for dataset_real.py. Loads from either:
1. Local manifest (dataset_build/dataset_manifest.json) — from build_dataset.py
2. HuggingFace Hub dataset (Ellaft/pc-fault-real-dataset) — if uploaded
Data sources: YouTube scraped audio/frames, HF cooling-fan recordings,
synthetic BIOS beep codes, synthetic HDD clicks, synthetic BSOD/POST/thermal images.
Usage — just change one import in train_v2.py:
from dataset_v2 import BuiltDataset as PCFaultDataset, multimodal_collate_fn
Or run train_v2.py with --dataset flag:
python train_v2.py --dataset local --dataset_dir ./dataset_build
python train_v2.py --dataset hub --hub_dataset Ellaft/pc-fault-real-dataset
"""
import os, json, random
import numpy as np
import torch
import torch.nn.functional as F
from torch.utils.data import Dataset
from PIL import Image, ImageEnhance
from pathlib import Path
from collections import Counter
try:
import torchaudio.transforms as T
HAS_TORCHAUDIO = True
except ImportError:
HAS_TORCHAUDIO = False
try:
import soundfile as sf
HAS_SOUNDFILE = True
except ImportError:
HAS_SOUNDFILE = False
try:
import librosa
HAS_LIBROSA = True
except ImportError:
HAS_LIBROSA = False
try:
from config import FAULT_CLASSES, DataConfig, ModelConfig
except ImportError:
# Standalone mode — define fault classes inline
FAULT_CLASSES = [
"normal_operation", "boot_failure", "overheating_fan",
"storage_failure", "system_crash",
]
DataConfig = None
ModelConfig = None
# ============================================================================
# Audio loading helpers
# ============================================================================
def load_audio_file(path, target_sr=16000):
    """Load an audio file and return (numpy_array, sample_rate).

    Note: the returned rate may differ from target_sr (only the librosa
    backend resamples on load), so callers should follow up with
    resample_audio()."""
if HAS_SOUNDFILE:
arr, sr = sf.read(path, dtype="float32")
if arr.ndim > 1:
arr = arr.mean(axis=1) # mono
return arr, sr
elif HAS_LIBROSA:
arr, sr = librosa.load(path, sr=target_sr, mono=True)
return arr, sr
elif HAS_TORCHAUDIO:
import torchaudio
waveform, sr = torchaudio.load(path)
if waveform.shape[0] > 1:
waveform = waveform.mean(dim=0, keepdim=True)
return waveform.squeeze(0).numpy(), sr
else:
raise ImportError("Need soundfile, librosa, or torchaudio to load audio. "
"Install: pip install soundfile")
def resample_audio(arr, orig_sr, target_sr=16000):
"""Resample audio array to target sample rate."""
if orig_sr == target_sr:
return arr
if HAS_TORCHAUDIO:
resampler = T.Resample(orig_sr, target_sr)
tensor = torch.tensor(arr, dtype=torch.float32).unsqueeze(0)
return resampler(tensor).squeeze(0).numpy()
elif HAS_LIBROSA:
return librosa.resample(arr, orig_sr=orig_sr, target_sr=target_sr)
else:
# Simple linear interpolation fallback
ratio = target_sr / orig_sr
new_len = int(len(arr) * ratio)
indices = np.linspace(0, len(arr) - 1, new_len)
return np.interp(indices, np.arange(len(arr)), arr).astype(np.float32)
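# A minimal usage sketch for the two helpers above (kept as comments so the
# module stays side-effect-free on import; "fan_recording.wav" is a
# hypothetical file, not part of this repo):
#
#   arr, sr = load_audio_file("fan_recording.wav", target_sr=16000)
#   arr = resample_audio(arr, sr, target_sr=16000)  # no-op if sr == 16000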
# ============================================================================
# Main Dataset Class
# ============================================================================
class BuiltDataset(Dataset):
"""
Loads multimodal PC fault dataset from build_dataset.py output.
Matches the exact interface of RealPCFaultDataset so train_v2.py works
without any changes — just swap the import.
Supports two modes:
- "local": Load from manifest JSON + local files (default)
- "hub": Load from HuggingFace Hub dataset
"""
def __init__(self, config, model_config, split="train",
vit_processor=None, ast_feature_extractor=None,
augment=True, val_ratio=0.15, test_ratio=0.15, seed=42,
# New parameters for v2 dataset
source="local", # "local" or "hub"
dataset_dir="./dataset_build",
hub_dataset="Ellaft/pc-fault-real-dataset"):
"""
Args:
config: DataConfig instance
model_config: ModelConfig instance (unused, kept for compat)
split: "train", "val", or "test"
vit_processor: ViT image processor
ast_feature_extractor: AST feature extractor
augment: Whether to apply data augmentation (train only)
val_ratio: Validation split ratio (for local mode)
test_ratio: Test split ratio (for local mode)
seed: Random seed for reproducibility
source: "local" (manifest files) or "hub" (HF dataset)
dataset_dir: Path to build_dataset.py output (local mode)
hub_dataset: HuggingFace dataset ID (hub mode)
"""
self.config = config
self.split = split
self.augment = augment and (split == "train")
self.vit_processor = vit_processor
self.ast_feature_extractor = ast_feature_extractor
self.target_sr = 16000 # AST expects 16kHz
self.audio_duration = config.audio_duration # seconds
self.target_audio_len = int(self.target_sr * self.audio_duration)
if source == "hub":
self._load_from_hub(hub_dataset, split, seed)
else:
self._load_from_local(dataset_dir, split, val_ratio, test_ratio, seed)
# Print statistics
lc = Counter(s["fault_label"] for s in self.samples)
n_has_audio = sum(1 for s in self.samples if s.get("audio_path") or s.get("audio_data") is not None)
n_has_image = sum(1 for s in self.samples if s.get("image_path") or s.get("image_data") is not None)
print(f"\n[BuiltDataset] {split}: {len(self.samples)} samples "
f"(audio: {n_has_audio}, images: {n_has_image})")
        for label_id in range(len(FAULT_CLASSES)):
print(f" {FAULT_CLASSES[label_id]}: {lc.get(label_id, 0)}")
def _load_from_local(self, dataset_dir, split, val_ratio, test_ratio, seed):
"""Load from build_dataset.py manifest."""
dataset_dir = Path(dataset_dir)
manifest_path = dataset_dir / "dataset_manifest.json"
if not manifest_path.exists():
raise FileNotFoundError(
f"Dataset manifest not found at {manifest_path}\n"
f"Run build_dataset.py first:\n"
f" cd data && python build_dataset.py --max_per_class 300")
print(f"[BuiltDataset] Loading from {manifest_path}")
with open(manifest_path) as f:
manifest = json.load(f)
all_samples = manifest["samples"]
print(f" Total samples in manifest: {len(all_samples)}")
# Convert manifest format to our internal format
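        # Each manifest entry carries the keys below (values illustrative):
        #   {"fault_class": 2, "audio_path": "audio/fan_0001.wav", "image_path": None}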
samples = []
for s in all_samples:
samples.append({
"fault_label": s["fault_class"],
"audio_path": s.get("audio_path"),
"image_path": s.get("image_path"),
})
# Stratified split
rng = random.Random(seed)
        by_class = {i: [] for i in range(len(FAULT_CLASSES))}
for s in samples:
by_class[s["fault_label"]].append(s)
train_samples, val_samples, test_samples = [], [], []
for cls_id, cls_samples in by_class.items():
rng.shuffle(cls_samples)
n = len(cls_samples)
n_test = max(1, int(n * test_ratio))
n_val = max(1, int(n * val_ratio))
n_train = n - n_val - n_test
test_samples.extend(cls_samples[:n_test])
val_samples.extend(cls_samples[n_test:n_test + n_val])
train_samples.extend(cls_samples[n_test + n_val:])
if split == "train":
self.samples = train_samples
elif split in ("val", "validation"):
self.samples = val_samples
else:
self.samples = test_samples
rng.shuffle(self.samples)
def _load_from_hub(self, hub_dataset, split, seed):
"""Load from HuggingFace Hub dataset."""
from datasets import load_dataset
# Map our split names to Hub split names
hub_split = {"val": "validation", "validation": "validation",
"train": "train", "test": "test"}.get(split, split)
print(f"[BuiltDataset] Loading from Hub: {hub_dataset} (split={hub_split})")
ds = load_dataset(hub_dataset, split=hub_split)
print(f" Loaded {len(ds)} samples")
self.hub_data = ds
        self.samples = []
        for i in range(len(ds)):
            row = ds[i]  # index each row once; this decodes Audio/Image features
            self.samples.append({
                "fault_label": row["fault_class"],
                "hub_idx": i,
                # Audio/image are decoded up front and held in memory,
                # which is fine for small datasets but costs RAM at scale.
                "audio_data": row.get("audio"),
                "image_data": row.get("image"),
            })
def __len__(self):
return len(self.samples)
def __getitem__(self, idx):
s = self.samples[idx]
fault_label = s["fault_label"]
# ---- Load Audio ----
audio_values = self._load_audio(s)
# ---- Load Image ----
pixel_values = self._load_image(s)
return {
"pixel_values": pixel_values,
"audio_values": audio_values,
"labels": torch.tensor(fault_label, dtype=torch.long),
}
def _load_audio(self, sample):
"""Load and process audio into AST-compatible format."""
arr = None
sr = self.target_sr
# Try Hub data first
if "audio_data" in sample and sample["audio_data"] is not None:
audio_data = sample["audio_data"]
if isinstance(audio_data, dict):
arr = np.array(audio_data["array"], dtype=np.float32)
sr = audio_data.get("sampling_rate", self.target_sr)
elif isinstance(audio_data, np.ndarray):
arr = audio_data.astype(np.float32)
# Try local file
elif sample.get("audio_path") and os.path.exists(sample["audio_path"]):
try:
arr, sr = load_audio_file(sample["audio_path"], self.target_sr)
except Exception as e:
print(f" âš  Failed to load audio {sample['audio_path']}: {e}")
arr = None
# Fallback: generate silence (model still gets image)
if arr is None:
arr = np.zeros(self.target_audio_len, dtype=np.float32)
sr = self.target_sr
# Ensure float32
arr = arr.astype(np.float32)
# Resample to 16kHz for AST
if sr != self.target_sr:
arr = resample_audio(arr, sr, self.target_sr)
# Pad/trim to target duration
if len(arr) < self.target_audio_len:
arr = np.pad(arr, (0, self.target_audio_len - len(arr)))
elif len(arr) > self.target_audio_len:
# Random crop during training, center crop during eval
if self.augment:
start = random.randint(0, len(arr) - self.target_audio_len)
else:
start = (len(arr) - self.target_audio_len) // 2
arr = arr[start:start + self.target_audio_len]
# Data augmentation (training only)
if self.augment:
arr = self._augment_audio(arr)
# Process with AST feature extractor
if self.ast_feature_extractor:
inputs = self.ast_feature_extractor(
arr, sampling_rate=self.target_sr,
return_tensors="pt")
audio_values = inputs["input_values"].squeeze(0)
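            # AST features come back with a fixed shape of
            # (max_length, num_mel_bins), typically (1024, 128) under the
            # default extractor settings.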
else:
# Fallback: raw waveform tensor
audio_values = torch.tensor(arr, dtype=torch.float32)
return audio_values
def _load_image(self, sample):
"""Load and process image into ViT-compatible format."""
img = None
# Try Hub data first
if "image_data" in sample and sample["image_data"] is not None:
img = sample["image_data"]
if not isinstance(img, Image.Image):
try:
img = Image.fromarray(np.array(img))
except Exception:
img = None
# Try local file
elif sample.get("image_path") and os.path.exists(sample["image_path"]):
try:
img = Image.open(sample["image_path"])
except Exception as e:
print(f" âš  Failed to load image {sample['image_path']}: {e}")
img = None
# Fallback: black image
if img is None:
img = Image.new("RGB", (224, 224), color=(0, 0, 0))
# Ensure RGB
if img.mode != "RGB":
img = img.convert("RGB")
# Data augmentation (training only)
if self.augment:
img = self._augment_image(img)
# Process with ViT processor
if self.vit_processor:
pixel_values = self.vit_processor(
images=img, return_tensors="pt")["pixel_values"].squeeze(0)
else:
# Manual normalization fallback
arr = np.array(img.resize((224, 224))).astype(np.float32) / 255.0
arr = (arr - [0.485, 0.456, 0.406]) / [0.229, 0.224, 0.225]
pixel_values = torch.tensor(arr, dtype=torch.float32).permute(2, 0, 1)
return pixel_values
def _augment_audio(self, arr):
"""Audio augmentation: noise injection, time shift, gain variation."""
# Random gain
if random.random() < 0.5:
gain = random.uniform(0.7, 1.3)
arr = arr * gain
# Add background noise
if random.random() < 0.3:
noise_level = random.uniform(0.001, 0.01)
arr = arr + np.random.randn(len(arr)).astype(np.float32) * noise_level
# Time shift
if random.random() < 0.3:
shift = random.randint(-int(0.1 * len(arr)), int(0.1 * len(arr)))
arr = np.roll(arr, shift)
return np.clip(arr, -1, 1).astype(np.float32)
def _augment_image(self, img):
"""Image augmentation: random crop, flip, brightness/contrast jitter."""
# Random horizontal flip
if random.random() < 0.5:
img = img.transpose(Image.FLIP_LEFT_RIGHT)
        # Random brightness variation
        if random.random() < 0.3:
            factor = random.uniform(0.8, 1.2)
            img = ImageEnhance.Brightness(img).enhance(factor)
        # Random contrast variation
        if random.random() < 0.3:
            factor = random.uniform(0.8, 1.2)
            img = ImageEnhance.Contrast(img).enhance(factor)
return img
# ============================================================================
# Collate function — same interface as dataset_real.py
# ============================================================================
def multimodal_collate_fn(batch):
"""
Collate function that handles variable-length audio.
Pads audio to the max length in the batch.
"""
pixel_values = torch.stack([b["pixel_values"] for b in batch])
labels = torch.stack([b["labels"] for b in batch])
audio_list = [b["audio_values"] for b in batch]
max_len = max(a.shape[-1] for a in audio_list)
padded_audio = []
for a in audio_list:
if a.shape[-1] < max_len:
pad_size = max_len - a.shape[-1]
a = F.pad(a, (0, pad_size))
padded_audio.append(a)
audio_values = torch.stack(padded_audio)
return {
"pixel_values": pixel_values,
"audio_values": audio_values,
"labels": labels,
}
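# ============================================================================
# Smoke test (a minimal sketch, not part of the training pipeline)
# ============================================================================
# Assumes build_dataset.py output exists at ./dataset_build. DataConfig is
# stood in by a SimpleNamespace with a hypothetical audio_duration of 5 s
# (the real DataConfig may differ), and both processors are left as None so
# the raw-tensor fallback paths are exercised.
if __name__ == "__main__":
    from types import SimpleNamespace
    from torch.utils.data import DataLoader

    cfg = SimpleNamespace(audio_duration=5.0)  # assumed value, see note above
    ds = BuiltDataset(cfg, None, split="train",
                      vit_processor=None, ast_feature_extractor=None,
                      source="local", dataset_dir="./dataset_build")
    loader = DataLoader(ds, batch_size=4, shuffle=True,
                        collate_fn=multimodal_collate_fn)
    batch = next(iter(loader))
    print("pixel_values:", tuple(batch["pixel_values"].shape))
    print("audio_values:", tuple(batch["audio_values"].shape))
    print("labels:", batch["labels"].tolist())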