| """ |
| Module: tokenization_args.py |
| |
| This module defines the `TokenizationArgs` dataclass, which encapsulates all the configurable parameters |
| required for the tokenization process in the TEDDY project. These parameters control how gene expression |
| data and biological annotations are tokenized for training. |
| |
| Main Features: |
| - Provides a structured way to define and manage tokenization arguments. |
| - Supports configuration for gene selection, sequence truncation, and annotation inclusion. |
| - Includes options for handling PerturbSeq-specific flags and preprocessing steps. |
| - Allows for flexible mapping of biological annotations (e.g., disease, tissue, cell type, sex). |
| - Enables reproducibility through random seed control for gene selection. |
| |
| Dependencies: |
| - `dataclasses`: For defining the `TokenizationArgs` dataclass. |
| |
| Usage: |
| 1. Import the `TokenizationArgs` class: |
| ```python |
| from teddy.tokenizer.tokenization_args import TokenizationArgs |
| ``` |
| 2. Define tokenization arguments for a specific tokenization task: |
| ```python |
| tokenization_args = TokenizationArgs( |
| tokenizer_name_or_path="path/to/tokenizer", |
| ... |
| ) |
| ``` |
| 3. Pass the `tokenization_args` object to the tokenization function: |
| ```python |
| tokenized_data = tokenize(data, tokenization_args) |
| ``` |
| """ |
|
|
from dataclasses import dataclass, field
from typing import Optional
|
|
|
|
@dataclass
class TokenizationArgs:
    """Configuration options for the tokenization step.

    Only ``tokenizer_name_or_path`` is required; every other field has a
    default. Each field carries a ``metadata["help"]`` string describing it
    (except ``add_disease_annotation``, which has none).

    Fields whose default is ``None`` are annotated ``Optional[...]``; ``None``
    means the option was not supplied.
    """

    # --- Tokenizer and gene selection -------------------------------------
    tokenizer_name_or_path: str = field(metadata={"help": "Path to tokenizer used."})
    gene_id_column: str = field(
        default="index", metadata={"help": "Field to use while accessing gene_ids for values."}
    )
    random_genes: bool = field(
        default=False,
        metadata={"help": "whether we want random genes (True) selection or top expressed ones (False)"},
    )
    # The help text here used to be a copy of the tokenizer-path description.
    include_zero_genes: bool = field(
        default=False,
        metadata={"help": "Whether to keep genes with zero expression values during tokenization."},
    )

    # --- Special tokens ---------------------------------------------------
    add_cls: bool = field(default=False, metadata={"help": "Whether to add cls token to the start of the sequence."})
    cls_token_id: Optional[int] = field(default=None, metadata={"help": "Token id for cls token."})
    perturbseq: bool = field(
        default=False,
        metadata={"help": "[PerturbSeq specific flag] Whether to add perturbation token during tokenization."},
    )
    tokenize_perturbseq_for_train: bool = field(
        default=True,
        metadata={
            "help": "[PerturbSeq specific flag] Whether to tokenize labels to prepare data for training or to simply prepare tokenized perturbation flags for inference."
        },
    )
    add_tokens: tuple = field(
        default=(),
        metadata={
            "help": "Enter a tuple of string values for tokens. Will be pre-pended to the gene id sequence. Can be used instead of add_cls"
        },
    )

    # NOTE(review): no help text exists for this flag; presumably it toggles
    # adding a disease annotation during tokenization — confirm with callers.
    add_disease_annotation: bool = field(default=False)

    # --- Labels, sharding and sequence shaping ----------------------------
    label_column: Optional[str] = field(
        default=None, metadata={"help": "Which column to use as a label for a classification task."}
    )
    max_shard_samples: int = field(default=500, metadata={"help": "Number of samples included in sharding."})
    max_seq_len: int = field(default=3001, metadata={"help": "Max seq length used for data processing"})
    pad_length: int = field(
        default=3001,
        metadata={"help": "Pad sequence to x length so that all arrays in all batches are same length"},
    )
    truncation_method: str = field(
        default="max",
        metadata={
            "help": "Indicate here how to restrict the number of genes to obtain max_seq_len from the full set of expression values. Options: max, random"
        },
    )
    bins: Optional[int] = field(
        default=None, metadata={"help": "Number of bins used when required for data processing"}
    )
    rescale_labels: bool = field(
        default=False, metadata={"help": "If true, labels are binned or continuously ranked"}
    )
    continuous_rank: bool = field(
        default=False, metadata={"help": "If true, gene values are overwritten with linspace[-1, 1] by rank."}
    )

    # --- Biological annotations -------------------------------------------
    bio_annotations: bool = field(
        default=False, metadata={"help": "If true, include disease, tissue type, cell type, sex"}
    )
    bio_annotation_masking_prob: float = field(
        default=0.15, metadata={"help": "Mask annotation tokens with this probability"}
    )
    disease_mapping: Optional[str] = field(
        default=None,
        metadata={"help": "Path to json mapping from disease names to standard disease categories"},
    )
    tissue_mapping: Optional[str] = field(
        default=None,
        metadata={"help": "Path to json mapping from tissue names to standard tissue categories"},
    )
    cell_mapping: Optional[str] = field(
        default=None,
        metadata={"help": "Path to json mapping from cell type names to standard cell types"},
    )
    sex_mapping: Optional[str] = field(
        default=None,
        metadata={"help": "Path to json mapping from sex names to standard sex categories"},
    )

    # --- I/O locations and reproducibility --------------------------------
    load_dir: str = field(default="", metadata={"help": "Directory where h5ad data is loaded from."})
    save_dir: str = field(
        default="",
        metadata={
            "help": "Directory where tokenization function will save data. tokenize() saves tokenized in data_path.replace(load_dir, save_dir)"
        },
    )
    gene_seed: int = field(
        default=42, metadata={"help": "Random seed that controls randomness of gene selection"}
    )
|
|