| """ |
| ESC-50 dataset utilities for loading and sampling audio data. |
| """ |
|
|
| import csv |
| import json |
| import random |
| from pathlib import Path |
| from typing import Dict, List, Optional, Tuple |
|
|
| import pandas as pd |
|
|
| from .logger import setup_logger |
|
|
| logger = setup_logger(__name__) |
|
|
|
|
def load_or_create_class_subset(config: dict, all_categories: List[str]) -> List[str]:
    """
    Load persisted class subset or create a new one.

    Args:
        config: Configuration dictionary with dataset.use_class_subset, etc.
        all_categories: List of all available categories

    Returns:
        List of category names to use (either subset or all)
    """
    dataset_config = config.get('dataset', {})
    use_subset = dataset_config.get('use_class_subset', False)

    if not use_subset:
        logger.info(f"Using all {len(all_categories)} classes")
        return all_categories

    num_classes = dataset_config.get('num_classes_subset', len(all_categories))
    persist_path = Path(dataset_config.get('subset_persist_path', 'class_subset.json'))
    subset_seed = dataset_config.get('subset_seed', 42)

    # Reuse a previously persisted subset only when it still matches the
    # requested size and every class is present in the current category list.
    if persist_path.exists():
        try:
            with open(persist_path, 'r') as f:
                data = json.load(f)
            subset = data.get('classes', [])

            if len(subset) == num_classes and all(c in all_categories for c in subset):
                logger.info(f"Loaded persisted class subset from {persist_path}: {len(subset)} classes")
                return subset
            else:
                logger.warning("Invalid persisted subset, regenerating...")
        except Exception as e:
            logger.warning(f"Failed to load persisted subset: {e}, regenerating...")

    # Use a dedicated RNG instance instead of random.seed() so we do not
    # clobber the module-level random state for other callers.
    # random.Random(seed).sample(...) yields the same sequence as seeding the
    # global RNG, so previously persisted subsets remain reproducible.
    rng = random.Random(subset_seed)
    subset = rng.sample(all_categories, min(num_classes, len(all_categories)))
    subset.sort()

    persist_path.parent.mkdir(parents=True, exist_ok=True)
    with open(persist_path, 'w') as f:
        json.dump({
            'classes': subset,
            'num_classes': len(subset),
            'seed': subset_seed,
            'total_available': len(all_categories)
        }, f, indent=2)

    logger.info(f"Created and persisted new class subset: {len(subset)} classes to {persist_path}")
    return subset
|
|
|
|
class ESC50Dataset:
    """Handler for ESC-50 dataset."""

    # All 50 sound categories present in the ESC-50 metadata.
    ALL_CATEGORIES = [
        'dog', 'chirping_birds', 'vacuum_cleaner', 'thunderstorm', 'door_wood_knock',
        'can_opening', 'crow', 'clapping', 'fireworks', 'chainsaw', 'airplane',
        'mouse_click', 'pouring_water', 'train', 'sheep', 'water_drops', 'church_bells',
        'clock_alarm', 'keyboard_typing', 'wind', 'footsteps', 'frog', 'cow',
        'brushing_teeth', 'car_horn', 'crackling_fire', 'helicopter', 'drinking_sipping',
        'rain', 'insects', 'laughing', 'hen', 'engine', 'breathing', 'crying_baby',
        'hand_saw', 'coughing', 'glass_breaking', 'snoring', 'toilet_flush', 'pig',
        'washing_machine', 'clock_tick', 'sneezing', 'rooster', 'sea_waves', 'siren',
        'cat', 'door_wood_creaks', 'crickets'
    ]

    def __init__(self, metadata_path: str, audio_path: str, config: Optional[dict] = None):
        """
        Initialize ESC-50 dataset handler.

        Args:
            metadata_path: Path to esc50.csv metadata file
            audio_path: Path to audio directory
            config: Optional configuration dict with dataset.use_class_subset settings
        """
        self.metadata_path = Path(metadata_path)
        self.audio_path = Path(audio_path)
        self.config = config or {}
        self.df = None  # populated by load_metadata()
        self.category_to_target = {}  # category name -> integer target ID
        self.target_to_category = {}  # integer target ID -> category name

        # Active subset of categories (may equal ALL_CATEGORIES when no
        # subset is configured); usage counts back the balanced samplers.
        self.CATEGORIES = load_or_create_class_subset(self.config, self.ALL_CATEGORIES)
        self.category_usage_counts = {cat: 0 for cat in self.CATEGORIES}

        self.load_metadata()

    def load_metadata(self):
        """Load ESC-50 metadata CSV and build category/target lookup tables."""
        try:
            self.df = pd.read_csv(self.metadata_path)
            logger.info(f"Loaded ESC-50 metadata: {len(self.df)} files")

            # Build bidirectional category <-> target maps from the metadata.
            for target, category in zip(self.df['target'], self.df['category']):
                self.category_to_target[category] = target
                self.target_to_category[target] = category

            logger.info(f"Found {len(self.category_to_target)} unique categories")
        except Exception as e:
            logger.error(f"Error loading metadata: {e}")
            raise

    def get_files_by_category(self, category: str) -> List[str]:
        """
        Get all audio files for a specific category.

        Args:
            category: Sound category name

        Returns:
            List of filenames for the category

        Raises:
            ValueError: If the category is not in the metadata.
        """
        if category not in self.category_to_target:
            raise ValueError(f"Unknown category: {category}")

        target = self.category_to_target[category]
        files = self.df[self.df['target'] == target]['filename'].tolist()
        return files

    def get_files_by_target(self, target: int) -> List[str]:
        """
        Get all audio files for a specific target ID.

        Args:
            target: Target class ID (0-49)

        Returns:
            List of filenames for the target
        """
        files = self.df[self.df['target'] == target]['filename'].tolist()
        return files

    def sample_categories(self, n: int, exclude: Optional[List[str]] = None) -> List[str]:
        """
        Sample n unique random categories from the active subset.

        Args:
            n: Number of categories to sample
            exclude: Optional list of categories to exclude

        Returns:
            List of sampled category names

        Raises:
            ValueError: If fewer than n categories remain after exclusions.
        """
        available = [c for c in self.CATEGORIES if c not in (exclude or [])]
        if n > len(available):
            raise ValueError(f"Cannot sample {n} categories from subset, only {len(available)} available (subset size: {len(self.CATEGORIES)})")
        return random.sample(available, n)

    def sample_targets(self, n: int, exclude: Optional[List[int]] = None) -> List[int]:
        """
        Sample n unique random targets from the active subset.

        Args:
            n: Number of targets to sample
            exclude: Optional list of targets to exclude

        Returns:
            List of sampled target IDs corresponding to categories in the subset

        Raises:
            ValueError: If fewer than n targets remain after exclusions.
        """
        # Restrict to targets whose categories are in the active subset.
        available_targets = [self.category_to_target[cat] for cat in self.CATEGORIES]
        available = [t for t in available_targets if t not in (exclude or [])]
        if n > len(available):
            raise ValueError(f"Cannot sample {n} targets from subset, only {len(available)} available (subset size: {len(self.CATEGORIES)})")
        return random.sample(available, n)

    def sample_file_from_category(self, category: str) -> Tuple[str, str]:
        """
        Sample a random audio file from a category.

        Args:
            category: Sound category name

        Returns:
            Tuple of (filename, full_path)
        """
        files = self.get_files_by_category(category)
        filename = random.choice(files)
        full_path = str(self.audio_path / filename)
        return filename, full_path

    def sample_file_from_target(self, target: int) -> Tuple[str, str, str]:
        """
        Sample a random audio file from a target.

        Args:
            target: Target class ID

        Returns:
            Tuple of (filename, category, full_path)
        """
        files = self.get_files_by_target(target)
        filename = random.choice(files)
        category = self.target_to_category[target]
        full_path = str(self.audio_path / filename)
        return filename, category, full_path

    def get_category_from_filename(self, filename: str) -> str:
        """Get category name from filename.

        Raises:
            ValueError: If the filename is not present in the metadata.
        """
        row = self.df[self.df['filename'] == filename]
        if len(row) == 0:
            # Include the offending filename so the error is actionable
            # (previously the message contained a literal placeholder).
            raise ValueError(f"Unknown filename: {filename}")
        return row.iloc[0]['category']

    def get_file_path(self, filename: str) -> str:
        """Get full path for a filename."""
        return str(self.audio_path / filename)

    def sample_categories_balanced(self, n: int, exclude: Optional[List[str]] = None,
                                   answer_category: Optional[str] = None) -> List[str]:
        """
        Sample n unique categories with balanced usage tracking.

        This method ensures that over many samples, all categories appear
        roughly equally as answers by preferentially sampling underused categories.

        Args:
            n: Number of categories to sample
            exclude: Optional list of categories to exclude
            answer_category: If provided, ensures this category is included and tracks it

        Returns:
            List of sampled category names with answer_category first if provided

        Raises:
            ValueError: If fewer than n categories remain after exclusions.
        """
        available = [c for c in self.CATEGORIES if c not in (exclude or [])]
        if n > len(available):
            raise ValueError(f"Cannot sample {n} categories, only {len(available)} available")

        if answer_category:
            # Track how often this category served as the answer.
            self.category_usage_counts[answer_category] += 1

            # Fill the remaining slots with distractors (answer excluded).
            available = [c for c in available if c != answer_category]
            other_categories = random.sample(available, n - 1)
            return [answer_category] + other_categories
        else:
            # No answer to pin: plain uniform sample.
            return random.sample(available, n)

    def get_least_used_categories(self, n: int, exclude: Optional[List[str]] = None) -> List[str]:
        """
        Get n categories that have been used least as answers.

        Args:
            n: Number of categories to get
            exclude: Optional list of categories to exclude

        Returns:
            List of least-used category names

        Raises:
            ValueError: If fewer than n categories remain after exclusions.
        """
        available = [c for c in self.CATEGORIES if c not in (exclude or [])]
        if n > len(available):
            raise ValueError(f"Cannot get {n} categories, only {len(available)} available")

        # Order categories by how often they have been used as answers.
        sorted_categories = sorted(available, key=lambda c: self.category_usage_counts[c])

        # Gather every category tied for the minimum usage count.
        min_count = self.category_usage_counts[sorted_categories[0]]
        candidates = [c for c in sorted_categories if self.category_usage_counts[c] == min_count]

        if len(candidates) >= n:
            # Enough least-used categories: pick randomly among the ties.
            return random.sample(candidates, n)
        else:
            # Take all least-used, then top up from the next usage tier(s).
            result = candidates.copy()
            remaining = n - len(result)
            next_tier = [c for c in sorted_categories if c not in candidates][:remaining]
            result.extend(next_tier)
            return result

    def get_category_usage_stats(self) -> Dict[str, int]:
        """Get current category usage statistics."""
        return self.category_usage_counts.copy()

    def reset_category_usage(self):
        """Reset category usage tracking."""
        self.category_usage_counts = {cat: 0 for cat in self.CATEGORIES}
        logger.info("Reset category usage tracking")
|
|
|
|
class PreprocessedESC50Dataset(ESC50Dataset):
    """
    Handler for preprocessed ESC-50 dataset with effective durations.

    Extends ESC50Dataset to use trimmed audio files and effective duration
    metadata from amplitude-based preprocessing.
    """

    def __init__(
        self,
        metadata_path: str,
        audio_path: str,
        preprocessed_path: str,
        config: Optional[dict] = None
    ):
        """
        Initialize preprocessed ESC-50 dataset handler.

        Args:
            metadata_path: Path to original esc50.csv metadata file
            audio_path: Path to original audio directory (fallback)
            preprocessed_path: Path to preprocessed data directory
            config: Optional configuration dict with dataset.use_class_subset settings
        """
        super().__init__(metadata_path, audio_path, config)

        self.preprocessed_path = Path(preprocessed_path)
        self.trimmed_audio_path = self.preprocessed_path / "trimmed_audio"
        self.effective_durations_path = self.preprocessed_path / "effective_durations.csv"

        # Populated by load_effective_durations().
        self.effective_df = None
        self.load_effective_durations()

    def load_effective_durations(self):
        """Load effective durations from preprocessed CSV and build lookups.

        Raises:
            Exception: Re-raises any failure reading/parsing the CSV after logging.
        """
        try:
            self.effective_df = pd.read_csv(self.effective_durations_path)
            logger.info(f"Loaded effective durations for {len(self.effective_df)} clips")

            # Per-file lookup tables for fast access.
            self.filename_to_effective = dict(
                zip(self.effective_df['filename'], self.effective_df['effective_duration_s'])
            )
            self.filename_to_category = dict(
                zip(self.effective_df['filename'], self.effective_df['category'])
            )

            # Per-category summary statistics of effective durations.
            self.category_effective_stats = self.effective_df.groupby('category').agg({
                'effective_duration_s': ['mean', 'std', 'min', 'max', 'count']
            }).round(4)
            # Flatten the MultiIndex columns produced by the multi-stat agg.
            self.category_effective_stats.columns = ['mean', 'std', 'min', 'max', 'count']

            logger.info("Created effective duration lookup tables")

        except Exception as e:
            logger.error(f"Error loading effective durations: {e}")
            raise

    def get_effective_duration(self, filename: str) -> float:
        """
        Get effective duration for a specific file.

        Args:
            filename: Audio filename

        Returns:
            Effective duration in seconds (defaults to 5.0 for unknown files,
            the nominal ESC-50 clip length)
        """
        if filename not in self.filename_to_effective:
            # Include the offending filename so the warning is actionable
            # (previously the message contained a literal placeholder).
            logger.warning(f"No effective duration for {filename}, using default 5.0s")
            return 5.0
        return self.filename_to_effective[filename]

    def get_category_effective_stats(self, category: str) -> Dict:
        """
        Get effective duration statistics for a category.

        Args:
            category: Category name

        Returns:
            Dict with mean, std, min, max, count (defaults for unknown categories)
        """
        if category not in self.category_effective_stats.index:
            return {'mean': 5.0, 'std': 0.0, 'min': 5.0, 'max': 5.0, 'count': 0}

        stats = self.category_effective_stats.loc[category]
        return {
            'mean': stats['mean'],
            'std': stats['std'],
            'min': stats['min'],
            'max': stats['max'],
            'count': int(stats['count'])
        }

    def get_files_by_category_with_durations(self, category: str) -> List[Dict]:
        """
        Get all files for a category with their effective durations.

        Args:
            category: Category name

        Returns:
            List of dicts with filename, effective_duration_s, filepath,
            raw_duration_s, peak_amplitude_db
        """
        cat_df = self.effective_df[self.effective_df['category'] == category]

        results = []
        for _, row in cat_df.iterrows():
            results.append({
                'filename': row['filename'],
                'effective_duration_s': row['effective_duration_s'],
                # Paths point at the trimmed (preprocessed) audio, not the originals.
                'filepath': str(self.trimmed_audio_path / row['filename']),
                'raw_duration_s': row['raw_duration_s'],
                'peak_amplitude_db': row['peak_amplitude_db']
            })

        return results

    def sample_file_from_category_with_duration(
        self,
        category: str,
        min_effective_duration: Optional[float] = None,
        max_effective_duration: Optional[float] = None
    ) -> Tuple[str, str, float]:
        """
        Sample a file from category with optional duration constraints.

        Args:
            category: Category name
            min_effective_duration: Minimum effective duration (optional)
            max_effective_duration: Maximum effective duration (optional)

        Returns:
            Tuple of (filename, filepath, effective_duration_s)
        """
        files = self.get_files_by_category_with_durations(category)

        # Apply each duration bound only when provided.
        if min_effective_duration is not None:
            files = [f for f in files if f['effective_duration_s'] >= min_effective_duration]
        if max_effective_duration is not None:
            files = [f for f in files if f['effective_duration_s'] <= max_effective_duration]

        if not files:
            # Best-effort fallback: no file satisfies the constraints,
            # so sample from the unconstrained pool rather than failing.
            logger.warning(f"No files match duration constraints for {category}, using any file")
            files = self.get_files_by_category_with_durations(category)

        selected = random.choice(files)
        return selected['filename'], selected['filepath'], selected['effective_duration_s']

    def sample_files_from_category_to_reach_duration(
        self,
        category: str,
        target_duration_s: float,
        prefer_same_file: bool = True
    ) -> Tuple[List[str], List[str], float]:
        """
        Sample files from a category to reach a target total effective duration.

        Args:
            category: Category name
            target_duration_s: Target total effective duration
            prefer_same_file: If True, try repeating same file first

        Returns:
            Tuple of (filenames_list, filepaths_list, actual_total_duration_s)

        Raises:
            ValueError: If the category has no files.
        """
        files = self.get_files_by_category_with_durations(category)

        if not files:
            raise ValueError(f"No files found for category: {category}")

        selected_filenames = []
        selected_filepaths = []
        total_duration = 0.0

        if prefer_same_file:
            # Repeat the single longest clip until the target is reached.
            files_sorted = sorted(files, key=lambda x: x['effective_duration_s'], reverse=True)
            selected_file = files_sorted[0]

            # Upper bound on repetitions; the loop breaks as soon as the
            # accumulated duration meets the target.
            reps_needed = max(1, int(target_duration_s / selected_file['effective_duration_s']) + 1)

            for _ in range(reps_needed):
                selected_filenames.append(selected_file['filename'])
                selected_filepaths.append(selected_file['filepath'])
                total_duration += selected_file['effective_duration_s']

                if total_duration >= target_duration_s:
                    break
        else:
            # Cycle through the category's files in random order.
            random.shuffle(files)
            file_idx = 0

            while total_duration < target_duration_s:
                selected_file = files[file_idx % len(files)]
                selected_filenames.append(selected_file['filename'])
                selected_filepaths.append(selected_file['filepath'])
                total_duration += selected_file['effective_duration_s']
                file_idx += 1

                # Safety valve against pathological targets/durations.
                if file_idx > 100:
                    logger.warning(f"Hit safety limit when sampling files for {category}")
                    break

        return selected_filenames, selected_filepaths, total_duration

    def get_categories_sorted_by_effective_duration(self, ascending: bool = True) -> List[str]:
        """
        Get categories sorted by their mean effective duration.

        Args:
            ascending: If True, shortest first; if False, longest first

        Returns:
            List of category names sorted by mean effective duration
        """
        sorted_stats = self.category_effective_stats.sort_values('mean', ascending=ascending)
        return sorted_stats.index.tolist()
|
|
|
|