| """ |
| This module provides PyTorch Dataset implementations for hierarchical VCF data |
| """ |
|
|
| import torch |
| import json |
| import pickle |
| import logging |
| from pathlib import Path |
| from typing import Dict, List, Tuple, Optional, Union, Any, Callable |
| from torch.utils.data import Dataset, DataLoader |
| import numpy as np |
| import pandas as pd |
|
|
| from datasets import Dataset as HFDataset, DatasetDict |
| from transformers import PreTrainedTokenizer |
|
|
| from config import DataConfig, ModelConfig, ConfigManager |
| from parser import VCFParser, MutationRecord |
| from tokenizer import HierarchicalVCFTokenizer, HierarchicalDataCollator |
|
|
|
|
| |
# NOTE(review): calling basicConfig at import time configures the root logger
# as a module side effect; libraries usually leave this to the application
# entry point — confirm this module is only used as a script/top-level module.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
|
|
|
|
class HierarchicalVCFDataset(Dataset):
    """PyTorch ``Dataset`` over hierarchically organized VCF mutation data.

    Each sample is a nested mapping ``pathway -> chromosome -> gene ->
    [mutation dicts]``. Samples are loaded, standardized, filtered and
    tokenizer-encoded eagerly in ``__init__`` so that ``__getitem__`` is
    cheap, and summary statistics are computed up front.
    """

    def __init__(self,
                 data_source: Union[str, Path, Dict, List],
                 tokenizer: HierarchicalVCFTokenizer,
                 config: Optional[DataConfig] = None,
                 labels: Optional[Union[List, np.ndarray]] = None,
                 transform: Optional[Callable] = None,
                 target_transform: Optional[Callable] = None,
                 cache_processed_data: bool = True):
        """
        Initialize the Hierarchical VCF Dataset.

        Args:
            data_source: Path to a ``.json``/``.pkl``/``.vcf`` file, or an
                already-loaded data dict (sample_id -> sample) / list of samples.
            tokenizer: Tokenizer used to encode each hierarchical sample.
            config: Data configuration; defaults to ``DataConfig()``.
            labels: Optional labels for supervised learning. Must align 1:1
                with the samples that survive filtering (validated below).
            transform: Optional transform applied to ``encoded_data`` in
                ``__getitem__``.
            target_transform: Optional transform applied to the label.
            cache_processed_data: Whether to cache processed data.

        Raises:
            FileNotFoundError: If ``data_source`` is a path that does not exist.
            ValueError: If no valid samples remain, or labels are misaligned.
        """
        self.config = config or DataConfig()
        self.tokenizer = tokenizer
        self.labels = labels
        self.transform = transform
        self.target_transform = target_transform
        self.cache_processed_data = cache_processed_data

        # Eagerly load and encode everything; filtering happens in _process_data.
        self.raw_data = self._load_data(data_source)
        self.processed_data = self._process_data()

        # Fail fast on empty datasets / label mismatches.
        self._validate_data()

        self.stats = self._compute_statistics()

        logger.info(f"Dataset initialized with {len(self.processed_data)} samples")
        logger.info(f"Dataset statistics: {self.stats}")

    def _load_data(self, data_source: Union[str, Path, Dict, List]) -> Dict[str, Any]:
        """Load raw data from a file path or pass through in-memory data.

        Lists are converted to a dict keyed ``sample_0 .. sample_{n-1}`` so the
        rest of the pipeline only deals with mappings.
        """
        if isinstance(data_source, (dict, list)):
            if isinstance(data_source, list):
                # Synthesize stable sample ids for anonymous list input.
                return {f"sample_{i}": sample for i, sample in enumerate(data_source)}
            return data_source

        data_path = Path(data_source)

        if not data_path.exists():
            raise FileNotFoundError(f"Data file not found: {data_path}")

        try:
            suffix = data_path.suffix.lower()
            if suffix == '.json':
                with open(data_path, 'r') as f:
                    return json.load(f)

            elif suffix == '.pkl':
                # SECURITY: pickle.load executes arbitrary code — only load
                # .pkl files from trusted sources.
                with open(data_path, 'rb') as f:
                    return pickle.load(f)

            elif suffix == '.vcf':
                parser = VCFParser(config=self.config)
                return parser.parse_vcf_file(data_path)

            else:
                raise ValueError(f"Unsupported file format: {data_path.suffix}")

        except Exception as e:
            logger.error(f"Error loading data from {data_path}: {e}")
            raise

    def _process_data(self) -> List[Dict[str, Any]]:
        """Standardize, filter and tokenizer-encode every raw sample.

        Samples that fail processing are logged and skipped rather than
        aborting the whole dataset build.
        """
        processed_samples = []

        for sample_id, sample_data in self.raw_data.items():
            try:
                standardized_sample = self._standardize_sample_format(sample_data)

                # Apply min/max mutation-count filtering from the config.
                if self._should_include_sample(standardized_sample):
                    encoded_sample = self.tokenizer.encode_hierarchical_sample(standardized_sample)

                    # NOTE(review): raw data is retained only when
                    # cache_processed_data is False; if the flag was meant to
                    # *keep* raw data when caching, this condition is inverted
                    # — confirm intent.
                    processed_sample = {
                        'sample_id': sample_id,
                        'encoded_data': encoded_sample,
                        'raw_data': standardized_sample if not self.cache_processed_data else None
                    }

                    processed_samples.append(processed_sample)

            except Exception as e:
                logger.warning(f"Error processing sample {sample_id}: {e}")
                continue

        return processed_samples

    def _standardize_sample_format(self, sample_data: Union[Dict[str, Any], List[Dict]]) -> Dict[str, Any]:
        """Coerce a sample into the hierarchical pathway->chrom->gene format.

        Accepts either a dict with a flat ``'mutations'`` list, an
        already-hierarchical dict (all values are dicts), or a bare flat list
        of mutation dicts.
        """
        # A dict carrying a flat mutation list under 'mutations'.
        # (For list input, `'mutations' in sample_data` is a membership test
        # against mutation dicts and is simply False.)
        if 'mutations' in sample_data:
            return self._convert_flat_to_hierarchical(sample_data['mutations'])

        elif isinstance(sample_data, dict) and all(
            isinstance(v, dict) for v in sample_data.values()
        ):
            # Already hierarchical; pass through unchanged.
            return sample_data

        else:
            # Bare flat list of mutations.
            return self._convert_flat_to_hierarchical(sample_data)

    def _convert_flat_to_hierarchical(self, mutations: List[Dict]) -> Dict[str, Any]:
        """Convert a flat mutation list to the nested hierarchical format.

        Missing grouping keys fall back to 'Unknown_Pathway' / 'Unknown' /
        'Unknown_Gene' so no mutation is dropped.
        """
        hierarchical: Dict[str, Any] = {}

        for mutation in mutations:
            pathway = mutation.get('pathway', 'Unknown_Pathway')
            chromosome = mutation.get('chromosome', mutation.get('chrom', 'Unknown'))
            gene = mutation.get('gene', mutation.get('gene_id', 'Unknown_Gene'))

            # Build the nested pathway -> chromosome -> gene buckets lazily.
            if pathway not in hierarchical:
                hierarchical[pathway] = {}
            if chromosome not in hierarchical[pathway]:
                hierarchical[pathway][chromosome] = {}
            if gene not in hierarchical[pathway][chromosome]:
                hierarchical[pathway][chromosome][gene] = []

            hierarchical[pathway][chromosome][gene].append(mutation)

        return hierarchical

    def _should_include_sample(self, sample_data: Dict[str, Any]) -> bool:
        """Determine if sample should be included based on filtering criteria.

        Counts every mutation across the hierarchy and checks it against the
        config's min/max per-sample bounds.
        """
        total_mutations = 0
        for pathway_data in sample_data.values():
            for chrom_data in pathway_data.values():
                for gene_mutations in chrom_data.values():
                    total_mutations += len(gene_mutations)

        if total_mutations < self.config.min_mutations_per_sample:
            return False

        if total_mutations > self.config.max_mutations_per_sample:
            return False

        return True

    def _validate_data(self) -> None:
        """Validate dataset integrity; raise ValueError on problems."""
        if len(self.processed_data) == 0:
            raise ValueError("No valid samples found in dataset")

        # Labels must align with the post-filtering sample count.
        if self.labels is not None:
            if len(self.labels) != len(self.processed_data):
                raise ValueError(
                    f"Number of labels ({len(self.labels)}) doesn't match "
                    f"number of samples ({len(self.processed_data)})"
                )

    def _compute_statistics(self) -> Dict[str, Any]:
        """Compute dataset-level statistics over the encoded samples.

        Returns a dict with unique pathway/chromosome/gene counts, per-sample
        mutation/gene/pathway counts, and mean/std summaries where available.
        """
        # Sets collect unique tokens; they are replaced by counts at the end
        # so the returned dict stays JSON-serializable.
        stats: Dict[str, Any] = {
            'num_samples': len(self.processed_data),
            'num_pathways': set(),
            'num_chromosomes': set(),
            'num_genes': set(),
            'mutations_per_sample': [],
            'genes_per_sample': [],
            'pathways_per_sample': []
        }

        for sample in self.processed_data:
            encoded_data = sample['encoded_data']

            sample_pathways = len(encoded_data)
            sample_genes = 0
            sample_mutations = 0

            for pathway_token, chromosomes in encoded_data.items():
                stats['num_pathways'].add(pathway_token)

                for chrom_token, genes in chromosomes.items():
                    stats['num_chromosomes'].add(chrom_token)

                    for gene_token, mutations in genes.items():
                        stats['num_genes'].add(gene_token)
                        sample_genes += 1

                        # Per-gene mutation count is taken from the 'impact'
                        # field length, matching the tokenizer's encoding.
                        if 'impact' in mutations:
                            sample_mutations += len(mutations['impact'])

            stats['mutations_per_sample'].append(sample_mutations)
            stats['genes_per_sample'].append(sample_genes)
            stats['pathways_per_sample'].append(sample_pathways)

        stats['unique_pathways'] = len(stats['num_pathways'])
        stats['unique_chromosomes'] = len(stats['num_chromosomes'])
        stats['unique_genes'] = len(stats['num_genes'])

        if stats['mutations_per_sample']:
            stats['avg_mutations_per_sample'] = np.mean(stats['mutations_per_sample'])
            stats['std_mutations_per_sample'] = np.std(stats['mutations_per_sample'])

        if stats['genes_per_sample']:
            stats['avg_genes_per_sample'] = np.mean(stats['genes_per_sample'])
            stats['std_genes_per_sample'] = np.std(stats['genes_per_sample'])

        # Drop the intermediate sets (not serializable, superseded by counts).
        del stats['num_pathways'], stats['num_chromosomes'], stats['num_genes']

        return stats

    def __len__(self) -> int:
        """Number of samples in the dataset."""
        return len(self.processed_data)

    def __getitem__(self, idx: int) -> Dict[str, Any]:
        """Return a single sample (shallow copy), applying transforms.

        Args:
            idx: Sample index; negative indices are supported Python-style.

        Raises:
            IndexError: If ``idx`` is out of range in either direction.
        """
        n = len(self.processed_data)
        # BUGFIX: normalize negative indices so the bounds check covers them;
        # previously a negative index silently bypassed the range check.
        if idx < 0:
            idx += n
        if not 0 <= idx < n:
            raise IndexError(f"Index {idx} out of range for dataset of size {n}")

        # Shallow copy: nested 'encoded_data' is still shared with the store.
        sample = self.processed_data[idx].copy()

        if self.transform:
            sample['encoded_data'] = self.transform(sample['encoded_data'])

        if self.labels is not None:
            label = self.labels[idx]
            if self.target_transform:
                label = self.target_transform(label)
            sample['label'] = label

        return sample

    def get_sample_by_id(self, sample_id: str) -> Optional[Dict[str, Any]]:
        """Linear-scan lookup of a sample by id; None if not found."""
        for i, sample in enumerate(self.processed_data):
            if sample['sample_id'] == sample_id:
                return self.__getitem__(i)
        return None

    def get_statistics(self) -> Dict[str, Any]:
        """Return a shallow copy of the precomputed dataset statistics."""
        return self.stats.copy()

    def save_dataset(self, save_path: Union[str, Path], format: str = 'pickle') -> None:
        """Persist processed data, labels, stats and config to disk.

        Args:
            save_path: Path to save the dataset (parent dirs are created).
            format: Save format ('pickle', 'json').

        Raises:
            ValueError: If ``format`` is not supported.
        """
        save_path = Path(save_path)
        save_path.parent.mkdir(parents=True, exist_ok=True)

        dataset_info = {
            'processed_data': self.processed_data,
            'labels': self.labels.tolist() if isinstance(self.labels, np.ndarray) else self.labels,
            'stats': self.stats,
            # Config is serialized as a plain attribute dict, not a DataConfig.
            'config': self.config.__dict__ if hasattr(self.config, '__dict__') else None
        }

        if format.lower() == 'pickle':
            with open(save_path, 'wb') as f:
                pickle.dump(dataset_info, f)

        elif format.lower() == 'json':
            # default=str stringifies anything non-serializable (e.g. numpy
            # scalars in stats) — lossy but round-trip friendly for inspection.
            with open(save_path, 'w') as f:
                json.dump(dataset_info, f, indent=2, default=str)

        else:
            raise ValueError(f"Unsupported save format: {format}")

        logger.info(f"Dataset saved to {save_path}")

    @classmethod
    def load_dataset(cls,
                     load_path: Union[str, Path],
                     tokenizer: HierarchicalVCFTokenizer,
                     format: str = 'auto') -> 'HierarchicalVCFDataset':
        """Reconstruct a dataset previously written by :meth:`save_dataset`.

        Args:
            load_path: Path to load the dataset from.
            tokenizer: Tokenizer instance to attach to the loaded dataset.
            format: Load format ('pickle', 'json', 'auto'); 'auto' picks
                pickle for ``.pkl`` files and json otherwise.

        Returns:
            Loaded dataset instance.

        Raises:
            FileNotFoundError: If ``load_path`` does not exist.
            ValueError: If ``format`` is not supported.
        """
        load_path = Path(load_path)

        if not load_path.exists():
            raise FileNotFoundError(f"Dataset file not found: {load_path}")

        if format == 'auto':
            format = 'pickle' if load_path.suffix == '.pkl' else 'json'

        if format.lower() == 'pickle':
            # SECURITY: pickle.load executes arbitrary code — trusted files only.
            with open(load_path, 'rb') as f:
                dataset_info = pickle.load(f)

        elif format.lower() == 'json':
            with open(load_path, 'r') as f:
                dataset_info = json.load(f)

        else:
            raise ValueError(f"Unsupported load format: {format}")

        # Bypass __init__ (no re-encoding) and restore state directly.
        dataset = cls.__new__(cls)
        dataset.tokenizer = tokenizer
        dataset.processed_data = dataset_info['processed_data']
        dataset.labels = dataset_info.get('labels')
        dataset.stats = dataset_info.get('stats', {})
        # NOTE(review): save_dataset stores config.__dict__, so a round-tripped
        # config is a plain dict here, not a DataConfig — confirm downstream
        # code does not rely on DataConfig attribute access after loading.
        dataset.config = dataset_info.get('config', DataConfig())
        dataset.transform = None
        dataset.target_transform = None
        dataset.cache_processed_data = True

        return dataset
|
|
|
|
class HierarchicalVCFDataModule:
    """
    Manage train/validation/test splits of hierarchical VCF data.

    Builds one full ``HierarchicalVCFDataset`` and carves index-based subsets
    out of it, optionally stratified by label.
    """

    def __init__(self,
                 data_source: Union[str, Path, Dict],
                 tokenizer: HierarchicalVCFTokenizer,
                 config: Optional[DataConfig] = None,
                 labels: Optional[Union[List, np.ndarray]] = None,
                 train_split: float = 0.8,
                 val_split: float = 0.1,
                 test_split: float = 0.1,
                 stratify: bool = True,
                 random_seed: int = 42):
        """
        Args:
            data_source: Source of the data (path or in-memory dict).
            tokenizer: Tokenizer for encoding.
            config: Data configuration.
            labels: Labels for supervised learning (enables stratification).
            train_split: Proportion for training.
            val_split: Proportion for validation.
            test_split: Proportion for testing.
            stratify: Whether to stratify splits by labels (requires labels).
            random_seed: Random seed for reproducibility.

        Raises:
            ValueError: If the three split proportions do not sum to 1.0.
        """
        self.config = config or DataConfig()
        self.tokenizer = tokenizer
        self.train_split = train_split
        self.val_split = val_split
        self.test_split = test_split
        self.stratify = stratify
        self.random_seed = random_seed

        # Tolerance-based check: the proportions must partition the data.
        if abs(train_split + val_split + test_split - 1.0) > 1e-6:
            raise ValueError("Train, validation, and test splits must sum to 1.0")

        # Build the full dataset once; the subsets reuse its processed samples.
        self.full_dataset = HierarchicalVCFDataset(
            data_source=data_source,
            tokenizer=tokenizer,
            config=config,
            labels=labels
        )

        self.train_dataset, self.val_dataset, self.test_dataset = self._create_splits()

        logger.info(f"Data module initialized:")
        logger.info(f"  Train: {len(self.train_dataset)} samples")
        logger.info(f"  Validation: {len(self.val_dataset)} samples")
        logger.info(f"  Test: {len(self.test_dataset)} samples")

    def _create_splits(self) -> Tuple[Dataset, Dataset, Dataset]:
        """Split sample indices into train/val/test and build subset datasets.

        Uses sklearn's stratified splitting when labels are available and
        stratify is requested; otherwise a simple shuffled partition.
        """
        # NOTE: seeds the *global* numpy RNG — affects other numpy consumers.
        np.random.seed(self.random_seed)

        indices = np.arange(len(self.full_dataset))

        if self.stratify and self.full_dataset.labels is not None:
            # Imported lazily so sklearn is only required for stratified splits.
            from sklearn.model_selection import train_test_split

            # First split off train; the remainder holds val + test.
            train_idx, temp_idx = train_test_split(
                indices,
                test_size=(self.val_split + self.test_split),
                stratify=[self.full_dataset.labels[i] for i in indices],
                random_state=self.random_seed
            )

            # Then divide the remainder between val and test.
            if self.test_split > 0:
                val_idx, test_idx = train_test_split(
                    temp_idx,
                    test_size=self.test_split / (self.val_split + self.test_split),
                    stratify=[self.full_dataset.labels[i] for i in temp_idx],
                    random_state=self.random_seed
                )
            else:
                val_idx = temp_idx
                # BUGFIX: must be an integer array — a default (float) empty
                # array breaks ndarray label indexing in _create_subset.
                test_idx = np.array([], dtype=int)

        else:
            # Plain shuffled contiguous partition by proportion.
            np.random.shuffle(indices)

            train_end = int(self.train_split * len(indices))
            val_end = int((self.train_split + self.val_split) * len(indices))

            train_idx = indices[:train_end]
            val_idx = indices[train_end:val_end]
            test_idx = indices[val_end:]

        train_dataset = self._create_subset(train_idx)
        val_dataset = self._create_subset(val_idx)
        test_dataset = self._create_subset(test_idx)

        return train_dataset, val_dataset, test_dataset

    def _create_subset(self, indices: np.ndarray) -> Dataset:
        """Create a subset dataset from indices.

        Bypasses HierarchicalVCFDataset.__init__ (no re-loading/re-encoding)
        and shares the already-processed sample dicts with the full dataset.
        """
        subset_data = [self.full_dataset.processed_data[i] for i in indices]
        subset_labels = None

        if self.full_dataset.labels is not None:
            if isinstance(self.full_dataset.labels, np.ndarray):
                # Fancy indexing; requires an integer index array.
                subset_labels = self.full_dataset.labels[indices]
            else:
                subset_labels = [self.full_dataset.labels[i] for i in indices]

        dataset = HierarchicalVCFDataset.__new__(HierarchicalVCFDataset)
        dataset.tokenizer = self.tokenizer
        dataset.config = self.config
        dataset.processed_data = subset_data
        dataset.labels = subset_labels
        dataset.transform = None
        dataset.target_transform = None
        dataset.cache_processed_data = True
        dataset.stats = dataset._compute_statistics()

        return dataset

    def get_dataloaders(self,
                        batch_size: int = 16,
                        num_workers: int = 0,
                        collate_fn: Optional[Callable] = None) -> Tuple[DataLoader, DataLoader, DataLoader]:
        """Build DataLoaders for the three splits.

        Args:
            batch_size: Batch size for data loading.
            num_workers: Number of worker processes.
            collate_fn: Custom collate function; defaults to the project's
                HierarchicalDataCollator.

        Returns:
            Tuple of (train_loader, val_loader, test_loader). Only the train
            loader shuffles.
        """
        if collate_fn is None:
            collate_fn = HierarchicalDataCollator(self.tokenizer)

        train_loader = DataLoader(
            self.train_dataset,
            batch_size=batch_size,
            shuffle=True,
            num_workers=num_workers,
            collate_fn=collate_fn
        )

        val_loader = DataLoader(
            self.val_dataset,
            batch_size=batch_size,
            shuffle=False,
            num_workers=num_workers,
            collate_fn=collate_fn
        )

        test_loader = DataLoader(
            self.test_dataset,
            batch_size=batch_size,
            shuffle=False,
            num_workers=num_workers,
            collate_fn=collate_fn
        )

        return train_loader, val_loader, test_loader
|
|
|
|
class HuggingFaceDatasetAdapter:
    """
    Convert hierarchical VCF data to Hugging Face Dataset format.
    """

    def __init__(self, vcf_dataset: HierarchicalVCFDataset):
        # Hold a reference to the source dataset; nothing is copied here.
        self.vcf_dataset = vcf_dataset

    def to_huggingface_dataset(self) -> DatasetDict:
        """
        Flatten every processed sample into one row and wrap the result in a
        single-split DatasetDict.

        Returns:
            HuggingFace DatasetDict with a single 'train' split.
        """
        labels = self.vcf_dataset.labels
        rows = []

        for position, record in enumerate(self.vcf_dataset.processed_data):
            encoded = record['encoded_data']

            row = {
                'sample_id': record['sample_id'],
                'pathways': list(encoded.keys()),
                'num_pathways': len(encoded),
                'encoded_mutations': self._flatten_mutations(encoded),
            }

            # Attach the label in row order when labels are available.
            if labels is not None:
                row['label'] = labels[position]

            rows.append(row)

        return DatasetDict({'train': HFDataset.from_list(rows)})

    def _flatten_mutations(self, encoded_data: Dict) -> Dict[str, List]:
        """Collapse the pathway/chromosome/gene hierarchy into flat lists."""
        # One bucket per mutation field; missing fields are simply skipped.
        buckets = {'impact': [], 'ref': [], 'alt': []}

        for chromosomes in encoded_data.values():
            for genes in chromosomes.values():
                for mutations in genes.values():
                    for field, bucket in buckets.items():
                        if field in mutations:
                            bucket.extend(mutations[field])

        return {
            'impacts': buckets['impact'],
            'refs': buckets['ref'],
            'alts': buckets['alt'],
        }
|
|
|
|
def create_dataset_from_config(config_manager: ConfigManager,
                               tokenizer: HierarchicalVCFTokenizer,
                               labels: Optional[List] = None) -> HierarchicalVCFDataset:
    """Build a HierarchicalVCFDataset from a ConfigManager's data config.

    Raises:
        ValueError: If the configuration has no VCF file path set.
    """
    cfg = config_manager.data_config
    vcf_path = cfg.vcf_file_path

    if not vcf_path:
        raise ValueError("VCF file path not specified in configuration")

    return HierarchicalVCFDataset(data_source=vcf_path,
                                  tokenizer=tokenizer,
                                  config=cfg,
                                  labels=labels)
|
|
|
|
def create_data_module_from_config(config_manager: ConfigManager,
                                   tokenizer: HierarchicalVCFTokenizer,
                                   labels: Optional[List] = None) -> HierarchicalVCFDataModule:
    """Build a HierarchicalVCFDataModule from a ConfigManager's data config.

    Raises:
        ValueError: If the configuration has no VCF file path set.
    """
    cfg = config_manager.data_config
    vcf_path = cfg.vcf_file_path

    if not vcf_path:
        raise ValueError("VCF file path not specified in configuration")

    return HierarchicalVCFDataModule(data_source=vcf_path,
                                     tokenizer=tokenizer,
                                     config=cfg,
                                     labels=labels)
|
|
|
|
| |
def create_synthetic_labels(dataset: 'HierarchicalVCFDataset',
                            label_type: str = 'random',
                            num_classes: int = 2) -> np.ndarray:
    """
    Create synthetic labels for testing purposes.

    Args:
        dataset: VCF dataset (must expose __len__ and
            stats['mutations_per_sample']).
        label_type: Type of labels ('random', 'mutation_count_based').
        num_classes: Number of classes for classification.

    Returns:
        Array of synthetic labels.

    Raises:
        ValueError: If ``label_type`` is unknown.
    """
    num_samples = len(dataset)

    if label_type == 'random':
        # Uniform class assignment from the global numpy RNG (unseeded here).
        return np.random.randint(0, num_classes, size=num_samples)

    elif label_type == 'mutation_count_based':
        mutation_counts = np.asarray(dataset.stats['mutations_per_sample'])

        if num_classes == 2:
            # Binary: above-median mutation burden -> class 1.
            threshold = np.median(mutation_counts)
            return (mutation_counts > threshold).astype(int)

        # Multi-class: bucket counts by interior percentile thresholds.
        # PERF FIX: thresholds were recomputed inside a per-sample loop;
        # they are loop-invariant, so compute once and vectorize.
        percentiles = np.linspace(0, 100, num_classes + 1)
        thresholds = np.percentile(mutation_counts, percentiles[1:-1])

        # digitize(..., right=True) counts thresholds strictly below each
        # value, matching the original strict `count > t` bucketing.
        return np.digitize(mutation_counts, thresholds, right=True)

    else:
        raise ValueError(f"Unknown label type: {label_type}")
|
|
|
|
| |
if __name__ == "__main__":
    # Smoke-test the dataset / data-module pipeline on a tiny in-memory example.
    from tokenizer import create_tokenizer_from_config

    cfg_manager = ConfigManager()
    cfg_manager.data_config.vcf_file_path = "example_data.json"

    demo_tokenizer = create_tokenizer_from_config(cfg_manager)

    # Two hand-written hierarchical samples: pathway -> chrom -> gene -> mutations.
    demo_data = {
        'sample1': {
            'pathway1': {
                'chr1': {
                    'gene1': [
                        {'impact': 'HIGH', 'reference': 'A', 'alternate': 'T'},
                        {'impact': 'MODERATE', 'reference': 'G', 'alternate': 'C'}
                    ]
                }
            }
        },
        'sample2': {
            'pathway2': {
                'chr2': {
                    'gene2': [
                        {'impact': 'LOW', 'reference': 'T', 'alternate': 'A'}
                    ]
                }
            }
        }
    }

    # The tokenizer needs a vocabulary before any sample can be encoded.
    demo_tokenizer.build_vocabulary(demo_data)

    demo_dataset = HierarchicalVCFDataset(
        data_source=demo_data,
        tokenizer=demo_tokenizer
    )

    # Attach random binary labels so supervised splitting can be exercised.
    synthetic_labels = create_synthetic_labels(demo_dataset, label_type='random', num_classes=2)
    demo_dataset.labels = synthetic_labels

    demo_module = HierarchicalVCFDataModule(
        data_source=demo_data,
        tokenizer=demo_tokenizer,
        labels=synthetic_labels,
        train_split=0.6,
        val_split=0.2,
        test_split=0.2
    )

    train_loader, val_loader, test_loader = demo_module.get_dataloaders(batch_size=2)

    # Peek at a single training batch, then stop.
    for first_batch in train_loader:
        print(f"Batch size: {first_batch['batch_size']}")
        print(f"Sample IDs: {[s.get('sample_id', 'N/A') for s in first_batch['samples']]}")
        break

    print(f"Dataset statistics: {demo_dataset.get_statistics()}")