| """ |
| This module contains all configuration parameters for the VCF processing pipeline |
| """ |
|
|
| from dataclasses import dataclass, field |
| from typing import Dict, List, Optional, Any |
| import json |
| import os |
|
|
|
|
@dataclass
class ModelConfig:
    """Hyperparameters for the hierarchical VCF model and its training loop."""

    # Embedding sizes.
    embed_dim: int = 32            # width of the per-token embedding
    transformer_dim: int = 128     # transformer hidden size

    # Transformer encoder settings.
    nhead: int = 8                 # attention heads; presumably transformer_dim must divide evenly by this — TODO confirm
    num_layers: int = 2
    dropout: float = 0.1           # must lie in [0, 1] (checked by ConfigManager.validate_config)

    # Classification head.
    num_classes: int = 2           # must be > 1 (checked by ConfigManager.validate_config)
    hidden_dims: List[int] = field(default_factory=lambda: [256, 128])  # MLP layer widths

    # Training loop.
    learning_rate: float = 1e-4
    batch_size: int = 16
    max_epochs: int = 100
    early_stopping_patience: int = 10   # epochs without improvement before stopping — TODO confirm against trainer

    # Caps on the hierarchical grouping (mutation -> gene -> chromosome -> pathway).
    # NOTE(review): presumably truncation limits applied by the data pipeline — verify against caller.
    max_mutations_per_gene: int = 100
    max_genes_per_chromosome: int = 1000
    max_chromosomes_per_pathway: int = 50
    max_pathways_per_sample: int = 100
|
|
|
|
@dataclass
class DataConfig:
    """Input/output paths and data-level constraints for VCF processing."""

    # File-system locations. The three input paths default to None and must be
    # supplied before data loading can run.
    vcf_file_path: Optional[str] = None
    gene_annotation_path: Optional[str] = None
    pathway_mapping_path: Optional[str] = None
    output_dir: str = "./outputs"
    cache_dir: str = "./cache"

    # Variant impact categories accepted by the pipeline (SnpEff/VEP-style
    # labels — TODO confirm which annotator these come from).
    supported_impacts: List[str] = field(default_factory=lambda: [
        "HIGH", "MODERATE", "LOW", "MODIFIER"
    ])
    # Human chromosome names, as bare identifiers without a "chr" prefix.
    supported_chromosomes: List[str] = field(default_factory=lambda: [
        "1", "2", "3", "4", "5", "6", "7", "8", "9", "10",
        "11", "12", "13", "14", "15", "16", "17", "18", "19", "20",
        "21", "22", "X", "Y", "MT"
    ])

    # BERT-style special tokens used when building sequences.
    special_tokens: Dict[str, str] = field(default_factory=lambda: {
        "pad_token": "[PAD]",
        "unk_token": "[UNK]",
        "sep_token": "[SEP]",
        "cls_token": "[CLS]"
    })

    # Per-sample mutation-count bounds; validate_config requires
    # 0 < min < max.
    min_mutations_per_sample: int = 1
    max_mutations_per_sample: int = 10000
|
|
|
|
@dataclass
class HuggingFaceConfig:
    """Metadata and publishing options for the Hugging Face Hub."""

    # Model identity.
    model_name: str = "GvEM"
    model_version: str = "1.0.0"
    model_description: str = "Genomic Variant Embedding Model"

    # Hub publishing. hub_model_id is the "user/repo" target and hub_token the
    # auth token; both are only needed when push_to_hub is True.
    push_to_hub: bool = False
    hub_model_id: Optional[str] = None
    hub_token: Optional[str] = None

    # Model-card metadata.
    license: str = "apache-2.0"
    tags: List[str] = field(default_factory=lambda: [
        "genomics", "vcf", "transformer", "hierarchical", "mutations"
    ])

    # Optional links for the model card.
    repository_url: Optional[str] = None
    paper_url: Optional[str] = None
|
|
|
|
class ConfigManager:
    """Load, save, and validate the pipeline's three configuration groups.

    Holds a ModelConfig, DataConfig, and HuggingFaceConfig and (de)serializes
    them to a single JSON file under the top-level keys 'model', 'data', and
    'huggingface'.
    """

    def __init__(self, config_path: Optional[str] = None):
        # Fall back to ./config.json when no explicit path is given.
        self.config_path = config_path or "config.json"
        self.model_config = ModelConfig()
        self.data_config = DataConfig()
        self.hf_config = HuggingFaceConfig()

    def load_config(self, config_path: Optional[str] = None) -> None:
        """Merge values from a JSON file into the current config objects.

        A missing file is a silent no-op; keys not present on the target
        dataclass are ignored.
        """
        path = config_path or self.config_path

        if os.path.exists(path):
            with open(path, 'r') as f:
                config_dict = json.load(f)

            if 'model' in config_dict:
                self._update_dataclass(self.model_config, config_dict['model'])
            if 'data' in config_dict:
                self._update_dataclass(self.data_config, config_dict['data'])
            if 'huggingface' in config_dict:
                self._update_dataclass(self.hf_config, config_dict['huggingface'])

    def save_config(self, config_path: Optional[str] = None) -> None:
        """Write all three config groups to a JSON file, creating parent dirs."""
        path = config_path or self.config_path

        config_dict = {
            'model': self._dataclass_to_dict(self.model_config),
            'data': self._dataclass_to_dict(self.data_config),
            'huggingface': self._dataclass_to_dict(self.hf_config)
        }

        # BUGFIX: os.makedirs("") raises FileNotFoundError, so only create the
        # parent directory when the path actually has one. The previous code
        # crashed for bare filenames such as the default "config.json".
        parent = os.path.dirname(path)
        if parent:
            os.makedirs(parent, exist_ok=True)
        with open(path, 'w') as f:
            json.dump(config_dict, f, indent=2)

    def _update_dataclass(self, dataclass_obj: Any, update_dict: Dict) -> None:
        """Copy values from update_dict onto attributes that already exist."""
        for key, value in update_dict.items():
            if hasattr(dataclass_obj, key):
                setattr(dataclass_obj, key, value)

    def _dataclass_to_dict(self, dataclass_obj: Any) -> Dict:
        """Convert a dataclass instance to a dict, skipping private attributes."""
        result = {}
        for key, value in dataclass_obj.__dict__.items():
            if not key.startswith('_'):
                result[key] = value
        return result

    def validate_config(self) -> bool:
        """Validate configuration parameters.

        Returns True on success; raises AssertionError describing the first
        violated constraint otherwise.
        """
        # Explicit raises instead of `assert` statements so validation still
        # runs under `python -O` (asserts are stripped there). AssertionError
        # is kept as the exception type for backward compatibility.
        if self.model_config.embed_dim <= 0:
            raise AssertionError("embed_dim must be positive")
        if self.model_config.nhead <= 0:
            raise AssertionError("nhead must be positive")
        if self.model_config.num_classes <= 1:
            raise AssertionError("num_classes must be > 1")
        if not 0 <= self.model_config.dropout <= 1:
            raise AssertionError("dropout must be in [0, 1]")

        if self.data_config.min_mutations_per_sample <= 0:
            raise AssertionError("min_mutations_per_sample must be positive")
        if self.data_config.max_mutations_per_sample <= self.data_config.min_mutations_per_sample:
            raise AssertionError(
                "max_mutations_per_sample must be > min_mutations_per_sample")

        return True

    def get_model_config_dict(self) -> Dict:
        """Return model config as a Hugging Face-style config dictionary."""
        return {
            'architectures': ['HierarchicalVCFModel'],
            'model_type': 'hierarchical-vcf',
            **self._dataclass_to_dict(self.model_config)
        }
|
|
| default_config = ConfigManager() |
|
|
# Example of the JSON structure accepted by ConfigManager.load_config:
# top-level sections 'model', 'data', and 'huggingface'; any omitted key keeps
# its dataclass default, and unknown keys are silently ignored.
EXAMPLE_CONFIG = {
    "model": {
        "embed_dim": 64,
        "transformer_dim": 256,
        "nhead": 8,
        "num_layers": 3,
        "num_classes": 5,
        "learning_rate": 5e-4,
        "batch_size": 32
    },
    "data": {
        "vcf_file_path": "/path/to/variants.vcf",
        "gene_annotation_path": "/path/to/gene_annotations.json",
        "pathway_mapping_path": "/path/to/pathway_mappings.json",
        "output_dir": "./results",
        "min_mutations_per_sample": 5,
        "max_mutations_per_sample": 5000
    },
    "huggingface": {
        "model_name": "my-vcf-model",
        "push_to_hub": True,
        "hub_model_id": "username/my-vcf-model",
        "license": "mit"
    }
}