| """ |
| Tokenization for VCF data with support for hierarchical structures |
| """ |
|
|
| import json |
| import pickle |
| import logging |
| from pathlib import Path |
| from collections import defaultdict, Counter |
| from typing import Dict, List, Tuple, Optional, Union, Any |
| import numpy as np |
|
|
| from transformers import PreTrainedTokenizer |
| from transformers.tokenization_utils import AddedToken |
|
|
| from config import DataConfig, ConfigManager |
| from parser import MutationRecord |
|
|
|
|
| |
| logging.basicConfig(level=logging.INFO) |
| logger = logging.getLogger(__name__) |
|
|
|
|
class HierarchicalVCFTokenizer(PreTrainedTokenizer):
    """Tokenizer for hierarchically structured VCF data.

    Samples are organised as sample -> pathway -> chromosome -> gene ->
    mutation list (see ``build_vocabulary``).  Instead of one flat vocabulary,
    an independent token<->id mapping is kept per mutation field
    (``self.mutation_fields``); an id is therefore only meaningful together
    with its field name.
    """

    # Filenames used by the HuggingFace save/load machinery.
    vocab_files_names = {
        "vocab_file": "vocab.json",
        "mutation_vocab_file": "mutation_vocab.json"
    }

    def __init__(self,
                 vocab_file: Optional[str] = None,
                 mutation_vocab_file: Optional[str] = None,
                 config: Optional[DataConfig] = None,
                 **kwargs):
        """Create the tokenizer.

        Args:
            vocab_file: Optional path to a previously saved per-field
                vocabulary (JSON); loaded only if the file exists.
            mutation_vocab_file: Optional path to a mutation vocabulary file;
                loaded through the same ``load_vocabulary`` code path.
            config: Data configuration supplying ``special_tokens``; a default
                ``DataConfig`` is constructed when omitted.
            **kwargs: Forwarded to ``PreTrainedTokenizer.__init__``.
        """
        self.config = config or DataConfig()

        # Resolve special tokens from the config, falling back to the
        # conventional BERT-style markers.
        special_tokens = self.config.special_tokens
        pad_token = special_tokens.get("pad_token", "[PAD]")
        unk_token = special_tokens.get("unk_token", "[UNK]")
        sep_token = special_tokens.get("sep_token", "[SEP]")
        cls_token = special_tokens.get("cls_token", "[CLS]")

        # NOTE(review): newer transformers releases invoke get_vocab() from
        # PreTrainedTokenizer.__init__, and get_vocab() below reads
        # self.field_vocabs, which is only assigned AFTER this call — confirm
        # this ordering works with the pinned transformers version.
        super().__init__(
            pad_token=pad_token,
            unk_token=unk_token,
            sep_token=sep_token,
            cls_token=cls_token,
            **kwargs
        )

        # One independent vocabulary per hierarchical/mutation field.
        self.mutation_fields = ['impact', 'ref', 'alt', 'chromosome', 'pathway', 'gene']
        self.field_vocabs: Dict[str, Dict[str, int]] = {}

        # Seed every field vocab with special tokens plus common genomic values.
        self._initialize_vocabularies()

        # Optionally restore previously saved vocabularies (silently skipped
        # when the paths are missing or not given).
        if vocab_file and Path(vocab_file).exists():
            self.load_vocabulary(vocab_file)

        if mutation_vocab_file and Path(mutation_vocab_file).exists():
            self.load_mutation_vocabulary(mutation_vocab_file)

        # Running counters; 'total_mutations' is declared but never
        # incremented anywhere in this class.
        self.tokenization_stats = {
            'total_samples': 0,
            'total_mutations': 0,
            'vocab_sizes': {}
        }

    def _initialize_vocabularies(self) -> None:
        """Seed each field vocabulary with the four special tokens at fixed ids 0-3."""
        for field in self.mutation_fields:
            self.field_vocabs[field] = {
                self.pad_token: 0,
                self.unk_token: 1,
                self.sep_token: 2,
                self.cls_token: 3
            }

        # Pre-register frequent genomic values on top of the special tokens.
        self._add_common_genomic_tokens()

    def _add_common_genomic_tokens(self) -> None:
        """Pre-register common genomic values so their ids are stable across corpora.

        To be made scalable and dynamic (original author note).
        """
        # Variant effect classes (HIGH/MODERATE/LOW/MODIFIER) — presumably the
        # SnpEff/VEP impact categories; TODO confirm against the parser.
        common_impacts = ["HIGH", "MODERATE", "LOW", "MODIFIER"]
        for impact in common_impacts:
            if impact not in self.field_vocabs['impact']:
                self.field_vocabs['impact'][impact] = len(self.field_vocabs['impact'])

        # Single-nucleotide alleles for both ref and alt fields; '-' presumably
        # denotes a gap/deletion — TODO confirm.
        nucleotides = ["A", "T", "G", "C", "N", "-"]
        for nt in nucleotides:
            for field in ['ref', 'alt']:
                if nt not in self.field_vocabs[field]:
                    self.field_vocabs[field][nt] = len(self.field_vocabs[field])

        # Human chromosomes 1-22 plus sex chromosomes and mitochondrial DNA.
        chromosomes = [str(i) for i in range(1, 23)] + ["X", "Y", "MT"]
        for chrom in chromosomes:
            if chrom not in self.field_vocabs['chromosome']:
                self.field_vocabs['chromosome'][chrom] = len(self.field_vocabs['chromosome'])

    def build_vocabulary(self, hierarchical_data: Dict[str, Any]) -> None:
        """Extend the per-field vocabularies from a parsed corpus.

        Tokens are appended in descending frequency order, so more frequent
        values receive smaller ids; already-known ids never change.

        Args:
            hierarchical_data: Parsed VCF data structure, laid out as
                ``{sample_id: {pathway_id: {chrom_id: {gene_id: [mutation, ...]}}}}``
                where each mutation is a ``MutationRecord`` or a dict with
                'impact'/'reference'/'alternate' keys.
        """
        logger.info("Building vocabularies from hierarchical data...")

        vocab_counters = {field: Counter() for field in self.mutation_fields}

        for sample_id, pathways in hierarchical_data.items():
            for pathway_id, chromosomes in pathways.items():

                vocab_counters['pathway'][pathway_id] += 1

                for chrom_id, genes in chromosomes.items():

                    vocab_counters['chromosome'][chrom_id] += 1

                    for gene_id, mutations in genes.items():

                        vocab_counters['gene'][gene_id] += 1

                        for mutation in mutations:
                            if isinstance(mutation, MutationRecord):

                                vocab_counters['impact'][mutation.impact] += 1
                                vocab_counters['ref'][mutation.reference] += 1
                                vocab_counters['alt'][mutation.alternate] += 1
                            elif isinstance(mutation, dict):
                                # Dict form: missing keys are counted as UNK.
                                vocab_counters['impact'][mutation.get('impact', self.unk_token)] += 1
                                vocab_counters['ref'][mutation.get('reference', self.unk_token)] += 1
                                vocab_counters['alt'][mutation.get('alternate', self.unk_token)] += 1

        # Append new tokens most-frequent-first; falsy tokens ('' / None) are skipped.
        for field, counter in vocab_counters.items():
            for token, count in counter.most_common():
                if token and token not in self.field_vocabs[field]:
                    self.field_vocabs[field][token] = len(self.field_vocabs[field])

        # Cache the resulting sizes for reporting.
        self.tokenization_stats['vocab_sizes'] = {
            field: len(vocab) for field, vocab in self.field_vocabs.items()
        }

        logger.info(f"Vocabulary sizes: {self.tokenization_stats['vocab_sizes']}")

    def encode_hierarchical_sample(self, sample_data: Dict[str, Any]) -> Dict[str, Any]:
        """Encode a single hierarchical sample into tokenized format.

        The nesting is preserved, but string keys are replaced by their integer
        ids (unknown keys map to the relevant field's UNK id) and each gene's
        mutation list becomes the field->id-list dict from ``_encode_mutations``.

        Args:
            sample_data: Single sample from hierarchical data:
                ``{pathway_id: {chrom_id: {gene_id: [mutation, ...]}}}``.
        Returns:
            ``{pathway_token: {chrom_token: {gene_token: {'impact'|'ref'|'alt': [ids]}}}}``.
        """
        encoded_sample = {}

        for pathway_id, chromosomes in sample_data.items():
            # NOTE(review): distinct unknown ids at any level all map to the
            # same UNK key and therefore overwrite each other in the dict.
            pathway_token = self.field_vocabs['pathway'].get(
                pathway_id, self.field_vocabs['pathway'][self.unk_token]
            )

            encoded_sample[pathway_token] = {}

            for chrom_id, genes in chromosomes.items():

                chrom_token = self.field_vocabs['chromosome'].get(
                    chrom_id, self.field_vocabs['chromosome'][self.unk_token]
                )

                encoded_sample[pathway_token][chrom_token] = {}

                for gene_id, mutations in genes.items():

                    gene_token = self.field_vocabs['gene'].get(
                        gene_id, self.field_vocabs['gene'][self.unk_token]
                    )

                    encoded_mutations = self._encode_mutations(mutations)
                    encoded_sample[pathway_token][chrom_token][gene_token] = encoded_mutations

        return encoded_sample

    def _encode_mutations(self, mutations: List[Union[MutationRecord, Dict]]) -> Dict[str, List[int]]:
        """Convert a gene's mutation list into parallel per-field id lists.

        Entries that are neither ``MutationRecord`` nor dict are silently
        skipped; unknown values map to the field's UNK id.
        """
        encoded_mutations = {
            'impact': [],
            'ref': [],
            'alt': []
        }

        for mutation in mutations:
            if isinstance(mutation, MutationRecord):
                impact = mutation.impact
                ref = mutation.reference
                alt = mutation.alternate
            elif isinstance(mutation, dict):
                impact = mutation.get('impact', self.unk_token)
                ref = mutation.get('reference', self.unk_token)
                alt = mutation.get('alternate', self.unk_token)
            else:
                # Unsupported mutation representation — skip it.
                continue

            # Look each value up in its own field vocabulary (UNK on a miss).
            encoded_mutations['impact'].append(
                self.field_vocabs['impact'].get(impact, self.field_vocabs['impact'][self.unk_token])
            )
            encoded_mutations['ref'].append(
                self.field_vocabs['ref'].get(ref, self.field_vocabs['ref'][self.unk_token])
            )
            encoded_mutations['alt'].append(
                self.field_vocabs['alt'].get(alt, self.field_vocabs['alt'][self.unk_token])
            )

        return encoded_mutations

    def encode_batch(self, batch_data: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
        """Encode a batch of hierarchical samples and bump the sample counter.

        Args:
            batch_data: List of sample dictionaries
                (see ``encode_hierarchical_sample``).
        Returns:
            List of encoded samples, in input order.
        """
        encoded_batch = []

        for sample_data in batch_data:
            encoded_sample = self.encode_hierarchical_sample(sample_data)
            encoded_batch.append(encoded_sample)

        self.tokenization_stats['total_samples'] += len(batch_data)

        return encoded_batch

    def decode_tokens(self, field: str, token_ids: List[int]) -> List[str]:
        """Decode token IDs for one field back to original string values.

        The reverse mapping is rebuilt on every call (O(vocab) per invocation)
        — fine for occasional debugging use.

        Args:
            field: Field name ('impact', 'ref', 'alt', etc.)
            token_ids: List of token IDs
        Returns:
            List of decoded tokens; unknown ids decode to the UNK token.
        Raises:
            ValueError: If ``field`` is not a known mutation field.
        """
        if field not in self.field_vocabs:
            raise ValueError(f"Unknown field: {field}")

        id_to_token = {v: k for k, v in self.field_vocabs[field].items()}
        return [id_to_token.get(token_id, self.unk_token) for token_id in token_ids]

    def get_vocab_size(self, field: str) -> int:
        """Get vocabulary size for a specific field.

        Raises:
            ValueError: If ``field`` is not a known mutation field.
        """
        if field not in self.field_vocabs:
            raise ValueError(f"Unknown field: {field}")
        return len(self.field_vocabs[field])

    def get_all_vocab_sizes(self) -> Dict[str, int]:
        """Get vocabulary sizes for all fields (field name -> size)."""
        return {field: len(vocab) for field, vocab in self.field_vocabs.items()}

    def save_vocabulary(self, save_directory: Union[str, Path], filename_prefix: Optional[str] = None) -> Tuple[str, ...]:
        """Write the per-field vocabularies and a tokenizer config to disk.

        Args:
            save_directory: Directory to save vocabularies (created if missing).
            filename_prefix: Optional prefix for filenames.

        Returns:
            2-tuple of saved file paths:
            (mutation vocab JSON, tokenizer config JSON).
        """
        save_directory = Path(save_directory)
        save_directory.mkdir(parents=True, exist_ok=True)

        prefix = f"{filename_prefix}_" if filename_prefix else ""

        # All per-field vocabularies go into one JSON document.
        mutation_vocab_file = save_directory / f"{prefix}mutation_vocab.json"
        with open(mutation_vocab_file, 'w') as f:
            json.dump(self.field_vocabs, f, indent=2)

        # Companion config describing special tokens and vocab sizes.
        config_file = save_directory / f"{prefix}tokenizer_config.json"
        config_data = {
            'tokenizer_class': self.__class__.__name__,
            'special_tokens': {
                'pad_token': self.pad_token,
                'unk_token': self.unk_token,
                'sep_token': self.sep_token,
                'cls_token': self.cls_token
            },
            'vocab_sizes': self.get_all_vocab_sizes(),
            'mutation_fields': self.mutation_fields
        }

        with open(config_file, 'w') as f:
            json.dump(config_data, f, indent=2)

        logger.info(f"Vocabularies saved to {save_directory}")

        return str(mutation_vocab_file), str(config_file)

    def load_vocabulary(self, vocab_file: Union[str, Path]) -> None:
        """Replace field vocabularies with those read from a JSON file.

        Only entries whose key is a known mutation field are applied; any
        other keys in the file are ignored.

        Raises:
            FileNotFoundError: If ``vocab_file`` does not exist.
        """
        vocab_file = Path(vocab_file)

        if not vocab_file.exists():
            raise FileNotFoundError(f"Vocabulary file not found: {vocab_file}")

        with open(vocab_file, 'r') as f:
            vocab_data = json.load(f)

        # Whole-field replacement: a field present in the file loses its
        # previously seeded tokens.
        for field, vocab in vocab_data.items():
            if field in self.mutation_fields:
                self.field_vocabs[field] = vocab

        logger.info(f"Vocabularies loaded from {vocab_file}")

    def load_mutation_vocabulary(self, mutation_vocab_file: Union[str, Path]) -> None:
        """Load mutation-specific vocabularies from file (alias for ``load_vocabulary``)."""
        self.load_vocabulary(mutation_vocab_file)

    def create_padding_masks(self, encoded_sample: Dict[str, Any], max_lengths: Dict[str, int]) -> Dict[str, Any]:
        """Create padding masks for hierarchical data.

        Each per-field id list is padded with the field's PAD id (or truncated)
        to ``max_lengths['mutations_<field>']`` (default 100 when the key is
        absent) and paired with a mask where 1 marks real tokens and 0 padding.

        Args:
            encoded_sample: Encoded sample data (output of
                ``encode_hierarchical_sample``).
            max_lengths: Maximum lengths for each level.
        Returns:
            Same nesting, with each field mapped to
            ``{'tokens': padded_ids, 'mask': mask}``.
        """
        masked_sample = {}

        for pathway_token, chromosomes in encoded_sample.items():
            masked_sample[pathway_token] = {}

            for chrom_token, genes in chromosomes.items():
                masked_sample[pathway_token][chrom_token] = {}

                for gene_token, mutations in genes.items():
                    masked_mutations = {}

                    for field, token_list in mutations.items():
                        max_len = max_lengths.get(f'mutations_{field}', 100)

                        # Pad short lists with the PAD id; truncate long ones.
                        if len(token_list) < max_len:
                            padded_list = token_list + [self.field_vocabs[field][self.pad_token]] * (max_len - len(token_list))
                            mask = [1] * len(token_list) + [0] * (max_len - len(token_list))
                        else:
                            padded_list = token_list[:max_len]
                            mask = [1] * max_len

                        masked_mutations[field] = {
                            'tokens': padded_list,
                            'mask': mask
                        }

                    masked_sample[pathway_token][chrom_token][gene_token] = masked_mutations

        return masked_sample

    def get_tokenization_statistics(self) -> Dict[str, Any]:
        """Return a (shallow) copy of the counters with up-to-date vocab sizes."""
        stats = self.tokenization_stats.copy()
        stats['vocab_sizes'] = self.get_all_vocab_sizes()
        return stats


    @property
    def vocab_size(self) -> int:
        """Total size across all field vocabularies (PreTrainedTokenizer API)."""
        return sum(len(vocab) for vocab in self.field_vocabs.values())

    def get_vocab(self) -> Dict[str, int]:
        """Flatten the per-field vocabularies into one mapping (PreTrainedTokenizer API).

        Keys are namespaced as ``"field:token"`` and ids are shifted by a
        cumulative per-field offset so they are unique across fields.  These
        shifted ids do NOT match the per-field ids used elsewhere in this class.
        """
        combined_vocab = {}
        offset = 0

        for field, vocab in self.field_vocabs.items():
            for token, idx in vocab.items():
                combined_vocab[f"{field}:{token}"] = idx + offset
            offset += len(vocab)

        return combined_vocab

    def _tokenize(self, text: str) -> List[str]:
        # Minimal whitespace split to satisfy the PreTrainedTokenizer API;
        # real encoding goes through encode_hierarchical_sample().
        return text.split()

    def _convert_token_to_id(self, token: str) -> int:
        # Tokens are expected in "field:value" form; the value is looked up
        # in that field's vocabulary (UNK id on a miss).
        if ':' in token:
            field, actual_token = token.split(':', 1)
            if field in self.field_vocabs:
                return self.field_vocabs[field].get(actual_token, self.field_vocabs[field][self.unk_token])

        # Un-namespaced or unknown-field tokens fall back to the 'impact'
        # field's UNK id (literal 1 if even that is missing).
        return self.field_vocabs.get('impact', {}).get(self.unk_token, 1)

    def _convert_id_to_token(self, index: int) -> str:
        # Ids are reused across fields, so this returns the FIRST field (in
        # mutation_fields order) that contains the id; each reverse map is
        # rebuilt per call (O(total vocab)).
        for field, vocab in self.field_vocabs.items():
            id_to_token = {v: k for k, v in vocab.items()}
            if index in id_to_token:
                return f"{field}:{id_to_token[index]}"

        return self.unk_token
|
|
|
|
class HierarchicalDataCollator:
    """Collates encoded hierarchical samples into model-ready batches.

    Each sample is run through the tokenizer's padding-mask builder, and
    per-sample structure counts (pathways / chromosomes / genes / mutations)
    are collected as batch metadata.
    """

    def __init__(self, tokenizer: HierarchicalVCFTokenizer, max_lengths: Optional[Dict[str, int]] = None):
        self.tokenizer = tokenizer
        # Conservative per-level caps used when the caller supplies nothing.
        defaults = {
            'mutations_impact': 50,
            'mutations_ref': 50,
            'mutations_alt': 50,
            'genes_per_chromosome': 100,
            'chromosomes_per_pathway': 25,
            'pathways_per_sample': 50,
        }
        self.max_lengths = max_lengths or defaults

    def __call__(self, batch: List[Dict[str, Any]]) -> Dict[str, Any]:
        """Collate a batch of encoded hierarchical samples.

        Args:
            batch: List of encoded hierarchical samples.
        Returns:
            Dict with padded 'samples', the 'batch_size', and per-sample
            structure counts under 'metadata'.
        """
        metadata_keys = ('num_pathways', 'num_chromosomes', 'num_genes', 'num_mutations')
        metadata = {key: [] for key in metadata_keys}
        padded_samples = []

        for sample in batch:
            padded_samples.append(self.tokenizer.create_padding_masks(sample, self.max_lengths))
            for key, count in self._count_levels(sample).items():
                metadata[key].append(count)

        return {
            'samples': padded_samples,
            'batch_size': len(batch),
            'metadata': metadata,
        }

    def _count_levels(self, sample: Dict[str, Any]) -> Dict[str, int]:
        """Count pathways, chromosomes, genes, and mutations in one raw sample."""
        chromosome_maps = list(sample.values())
        gene_maps = [genes for chroms in chromosome_maps for genes in chroms.values()]
        mutation_maps = [muts for genes in gene_maps for muts in genes.values()]
        return {
            'num_pathways': len(sample),
            'num_chromosomes': sum(len(chroms) for chroms in chromosome_maps),
            'num_genes': sum(len(genes) for genes in gene_maps),
            # Mutation count is taken from the 'impact' field's id list.
            'num_mutations': sum(len(muts.get('impact', [])) for muts in mutation_maps),
        }
|
|
|
|
def create_tokenizer_from_config(config_manager: ConfigManager) -> HierarchicalVCFTokenizer:
    """Build a HierarchicalVCFTokenizer from the data config held by *config_manager*."""
    data_cfg = config_manager.data_config
    return HierarchicalVCFTokenizer(config=data_cfg)
|
|
|
|
| |
if __name__ == "__main__":
    # Smoke test: build a tokenizer, fit vocabularies on a tiny example,
    # round-trip one sample, and persist the result.
    manager = ConfigManager()
    vcf_tokenizer = create_tokenizer_from_config(manager)

    # Minimal hierarchical fixture: one sample / pathway / chromosome / gene.
    demo_data = {
        'sample1': {
            'pathway1': {
                'chr1': {
                    'gene1': [
                        {'impact': 'HIGH', 'reference': 'A', 'alternate': 'T'},
                    ]
                }
            }
        }
    }

    # Extend the vocabularies with the example's tokens.
    vcf_tokenizer.build_vocabulary({'sample1': demo_data['sample1']})

    # Encode the single sample and show the id structure.
    encoded = vcf_tokenizer.encode_hierarchical_sample(demo_data['sample1'])
    print(f"Encoded sample: {encoded}")

    # Persist vocabularies + config next to the script.
    vcf_tokenizer.save_vocabulary("./tokenizer_files")

    print(f"Tokenization statistics: {vcf_tokenizer.get_tokenization_statistics()}")