| """ |
| Module: classification_heads.py |
| |
| This module defines various classification and decoder heads for use in transformer-based models, |
| specifically tailored for single-cell biology tasks. These heads are designed to handle tasks such as |
| classification, regression, and expression value prediction, and they integrate seamlessly with |
| transformer architectures. |
| |
| Main Features: |
| - **ClsDecoder**: A simple decoder for classification tasks, supporting multiple layers and activations. |
| - **ClassificationHead**: A RoBERTa-style classification head for downstream tasks. |
| - **ClassificationHeadAnalysis**: An extended classification head that provides intermediate hidden states for analysis. |
| - **ClsDecoderAnalysis**: A classification decoder with support for hidden state extraction. |
| - **TrainingHead**: A dense layer with activation and normalization for training tasks. |
| - **AnnotationDecoderHead**: A lightweight decoder for annotation tasks with simplified weight initialization. |
| - **ExprDecoder**: A decoder for predicting gene expression values, with optional explicit zero probability prediction. |
| - **AffineExprDecoder**: A decoder for predicting gene expression values in an affine form (Ax + b), with support for |
| advanced features like adaptive bias and explicit zero probabilities. |
| |
| Dependencies: |
| - PyTorch: For defining and training neural network components. |
| - Transformers: For activation functions and integration with transformer-based models. |
| |
| Usage: |
| Import the desired classification or decoder head into your model: |
| ```python |
| from teddy.models.classification_heads import ClsDecoder, ClassificationHead |
| ``` |
| """ |
|
|
| from typing import Dict, Optional |
|
|
| import torch |
| import torch.nn as nn |
| from torch import Tensor |
| from transformers.activations import ACT2FN |
|
|
|
|
class ClsDecoder(nn.Module):
    """Decoder head for classification tasks.

    Stacks ``nlayers - 1`` blocks of ``Linear -> activation -> LayerNorm``
    (all of width ``d_model``), followed by a final projection to ``n_cls``
    logits.

    Args:
        d_model: Input embedding dimension.
        n_cls: Number of output classes.
        nlayers: Total layer count; ``nlayers - 1`` hidden blocks are built,
            so the default ``nlayers=1`` yields a single linear projection.
        activation: Zero-argument callable returning an activation module
            (e.g. ``nn.ReLU``), instantiated once per hidden block.
    """

    def __init__(
        self,
        d_model: int,
        n_cls: int,
        nlayers: int = 1,
        activation: callable = nn.ReLU,
    ):
        super().__init__()
        self._decoder = nn.ModuleList()
        for _i in range(nlayers - 1):
            self._decoder.append(nn.Linear(d_model, d_model))
            self._decoder.append(activation())
            self._decoder.append(nn.LayerNorm(d_model))
        self.out_layer = nn.Linear(d_model, n_cls)

    def forward(self, x: Tensor) -> Dict[str, Tensor]:
        """Run the decoder stack and final projection.

        Args:
            x: Tensor of shape ``[batch_size, d_model]``.

        Returns:
            Dict with key ``"output"`` holding logits of shape
            ``[batch_size, n_cls]``.
        """
        # Note: the original annotation said ``-> Tensor`` but a dict is
        # returned; the annotation above reflects the actual behavior.
        for layer in self._decoder:
            x = layer(x)
        return {"output": self.out_layer(x)}
|
|
|
|
class ClassificationHead(nn.Module):
    """RoBERTa-style classification head.

    Builds ``nlayers`` blocks of ``Dropout -> Linear(d_model, d_model) ->
    activation`` (the activation module is shared across blocks), then a
    final ``Dropout -> Linear(d_model, n_cls)`` projection.
    """

    def __init__(self, config, n_cls, nlayers):
        super().__init__()
        if config.layer_activation == "relu":
            self.activation = nn.ReLU()
        else:
            self.activation = nn.GELU()

        stages = []
        for _ in range(nlayers):
            stages.append(nn.Dropout(config.dropout))
            stages.append(nn.Linear(config.d_model, config.d_model))
            stages.append(self.activation)
        stages.append(nn.Dropout(config.dropout))
        stages.append(nn.Linear(config.d_model, n_cls))
        self._decoder = nn.ModuleList(stages)

    def forward(self, x):
        """Apply every decoder stage in order; return ``{"output": logits}``."""
        for stage in self._decoder:
            x = stage(x)
        return {"output": x}
|
|
|
|
class ClassificationHeadAnalysis(nn.Module):
    """RoBERTa-style classification head that records intermediate states.

    Identical stack to ``ClassificationHead`` except that a single shared
    ``Dropout`` module is reused throughout, and ``forward`` also returns
    the output of every ``Linear`` layer for downstream analysis.
    """

    def __init__(self, config, n_cls, nlayers):
        super().__init__()
        self.dropout = nn.Dropout(config.dropout)
        stages = []
        self.activation = nn.ReLU() if config.layer_activation == "relu" else nn.GELU()

        for _ in range(nlayers):
            stages.append(self.dropout)
            stages.append(nn.Linear(config.d_model, config.d_model))
            stages.append(self.activation)
        stages.append(self.dropout)
        stages.append(nn.Linear(config.d_model, n_cls))
        self._decoder = nn.ModuleList(stages)

    def forward(self, x):
        """Run the head; capture the output of each Linear layer.

        Returns ``{"output": logits, "hidden_states": [per-Linear outputs]}``.
        """
        captured = []
        for stage in self._decoder:
            x = stage(x)
            if isinstance(stage, nn.Linear):
                captured.append(x)
        return {"output": x, "hidden_states": captured}
|
|
|
|
class ClsDecoderAnalysis(nn.Module):
    """Classification decoder that also exposes intermediate activations.

    Same architecture as ``ClsDecoder`` (``nlayers - 1`` blocks of
    ``Linear -> activation -> LayerNorm`` plus a final projection), but
    ``forward`` records the output of every intermediate module.

    Args:
        d_model: Input embedding dimension.
        n_cls: Number of output classes.
        nlayers: Total layer count; ``nlayers - 1`` hidden blocks are built.
        activation: Zero-argument callable returning an activation module.
    """

    def __init__(
        self,
        d_model: int,
        n_cls: int,
        nlayers: int = 3,
        activation: callable = nn.ReLU,
    ):
        super().__init__()
        self._decoder = nn.ModuleList()
        for _i in range(nlayers - 1):
            self._decoder.append(nn.Linear(d_model, d_model))
            self._decoder.append(activation())
            self._decoder.append(nn.LayerNorm(d_model))
        self.out_layer = nn.Linear(d_model, n_cls)

    def forward(self, x: Tensor) -> Dict[str, Tensor]:
        """Run the decoder, capturing every intermediate activation.

        Args:
            x: Tensor of shape ``[batch_size, d_model]``.

        Returns:
            Dict with ``"output"`` logits of shape ``[batch_size, n_cls]``
            and ``"hidden_states"``, a list with one tensor per module in
            the hidden stack (3 entries per hidden block).
        """
        # Note: the original annotation said ``-> Tensor`` but a dict is
        # returned; the annotation above reflects the actual behavior.
        hidden_states = []
        for layer in self._decoder:
            x = layer(x)
            hidden_states.append(x)
        return {"output": self.out_layer(x), "hidden_states": hidden_states}
|
|
|
|
class TrainingHead(nn.Module):
    """Dense -> activation -> LayerNorm projection used during training.

    The activation is looked up from ``ACT2FN`` by the name in
    ``config.layer_activation``; ``config.layer_norm_eps`` is passed as the
    LayerNorm epsilon.
    """

    def __init__(self, config):
        super().__init__()
        self.dense = nn.Linear(config.d_model, config.d_model)
        self.activation = ACT2FN[config.layer_activation]
        self.LayerNorm = nn.LayerNorm(config.d_model, config.layer_norm_eps)

    def forward(self, hidden_states):
        """Project, activate, then layer-normalize the hidden states."""
        projected = self.dense(hidden_states)
        activated = self.activation(projected)
        return self.LayerNorm(activated)
|
|
|
|
class AnnotationDecoderHead(nn.Linear):
    """Small class to make weight initialization easier"""

    def __init__(self, d_model, n_token):
        # Bias-free linear projection from embeddings to token logits;
        # subclassing nn.Linear lets init code target this head by type.
        super().__init__(in_features=d_model, out_features=n_token, bias=False)
|
|
|
|
class ExprDecoder(nn.Module):
    """Decoder predicting a scalar expression value per position.

    Args:
        d_model: Embedding dimension of the transformer output.
        explicit_zero_prob: If True, a second MLP predicts a per-position
            probability of the value being zero (returned as
            ``"zero_probs"``).
        use_batch_labels: If True, the input is expected to be the
            concatenation of the embedding with a batch-label embedding,
            doubling the input width.
    """

    def __init__(
        self,
        d_model: int,
        explicit_zero_prob: bool = False,
        use_batch_labels: bool = False,
    ):
        super().__init__()
        d_in = d_model * 2 if use_batch_labels else d_model
        # Both heads share the same 3-Linear / LeakyReLU architecture;
        # build it once via the helper to avoid duplicated construction.
        self.fc = self._build_mlp(d_in, d_model)
        self.explicit_zero_prob = explicit_zero_prob
        if explicit_zero_prob:
            self.zero_logit = self._build_mlp(d_in, d_model)

    @staticmethod
    def _build_mlp(d_in: int, d_model: int) -> nn.Sequential:
        """Three-layer LeakyReLU MLP mapping ``(..., d_in) -> (..., 1)``."""
        return nn.Sequential(
            nn.Linear(d_in, d_model),
            nn.LeakyReLU(),
            nn.Linear(d_model, d_model),
            nn.LeakyReLU(),
            nn.Linear(d_model, 1),
        )

    def forward(self, x: Tensor, values: Optional[Tensor] = None) -> Dict[str, Tensor]:
        """Predict expression values (and optionally zero probabilities).

        Args:
            x: Transformer output, shape ``(batch, seq_len, d_model)``
                (or ``2 * d_model`` when ``use_batch_labels`` is set).
            values: Accepted for interface parity with other decoders but
                not used by this implementation.

        Returns:
            Dict with ``"pred"`` of shape ``(batch, seq_len)`` and, when
            ``explicit_zero_prob`` is set, ``"zero_probs"`` in ``[0, 1]``
            of the same shape.
        """
        pred_value = self.fc(x).squeeze(-1)

        if not self.explicit_zero_prob:
            return {"pred": pred_value}
        zero_logits = self.zero_logit(x).squeeze(-1)
        zero_probs = torch.sigmoid(zero_logits)
        return {"pred": pred_value, "zero_probs": zero_probs}
| |
| |
| |
| |
| |
|
|
|
|
class AffineExprDecoder(nn.Module):
    def __init__(
        self,
        d_model: int,
        explicit_zero_prob: bool = False,
        activation: Optional[str] = None,
        tanh_coeff: bool = False,
        adaptive_bias: bool = False,
    ):
        """
        Predict the expression value of each gene in an affine form Ax + b.
        This decoder uses two ExprDecoder instances internally to generate
        the coefficient A and the bias b.

        Args:
            d_model: The embedding dimension.
            explicit_zero_prob: If True, also return the probability of each
                gene being zero (taken from the coefficient decoder).
            activation: Optional activation name ("gelu", "relu", "tanh" or
                "sigmoid") applied to both the coefficient and the bias.
            tanh_coeff: If True, use tanh activation for the coefficient A.
                NOTE(review): currently stored but never applied in
                ``forward`` — confirm whether this is intentional.
            adaptive_bias: If True, scale the predicted bias by the
                per-sample mean of the non-zero input values.
        """
        super().__init__()
        self.explicit_zero_prob = explicit_zero_prob
        self.tanh_coeff = tanh_coeff
        self.adaptive_bias = adaptive_bias
        self.coeff_decoder = ExprDecoder(d_model, explicit_zero_prob=explicit_zero_prob)
        self.bias_decoder = ExprDecoder(d_model, explicit_zero_prob=explicit_zero_prob)
        self.activation = activation

        if activation is not None:
            activation = activation.lower()
            # Map the lowercase name to the torch.nn class name.
            activations_map = {
                "gelu": "GELU",
                "relu": "ReLU",
                "tanh": "Tanh",
                "sigmoid": "Sigmoid",
            }
            assert activation in activations_map, f"Unknown activation: {activation}"
            assert hasattr(nn, activations_map[activation]), f"Unknown activation: {activation}"
            # Replace the stored name with the instantiated module.
            self.activation = getattr(nn, activations_map[activation])()

    def forward(self, x: Tensor, values: Tensor) -> Dict[str, Tensor]:
        """
        Args:
            x: Tensor, shape [batch_size, seq_len, embsize]
            values: Tensor, shape [batch_size, seq_len]

        Returns:
            Dict with "pred" of shape [batch_size, seq_len] computed as
            coeff * values + bias, plus "zero_probs" (same shape, from the
            coefficient decoder) when ``explicit_zero_prob`` is set.
        """
        # Note: the original annotation said ``-> Tensor`` but a dict is
        # returned; the annotation above reflects the actual behavior.
        coeff = self.coeff_decoder(x)
        bias = self.bias_decoder(x)

        if self.activation is not None:
            coeff["pred"] = self.activation(coeff["pred"])
            bias["pred"] = self.activation(bias["pred"])

        if self.adaptive_bias:
            # Per-sample mean over the non-zero entries of `values`.
            # NOTE(review): a row with no non-zero values divides by zero,
            # yielding NaN/inf — confirm callers guarantee at least one
            # non-zero value per row when adaptive_bias is enabled.
            non_zero_value_mean = values.sum(dim=1, keepdim=True) / (values != 0).sum(dim=1, keepdim=True)
            bias["pred"] = bias["pred"] * non_zero_value_mean

        if self.explicit_zero_prob:
            return {
                "pred": coeff["pred"] * values + bias["pred"],
                "zero_probs": coeff["zero_probs"],
            }

        return {"pred": coeff["pred"] * values + bias["pred"]}
|
|