# FateFormerExplorer/utils/helpers.py
# Repository page header (not code): "kaveh's picture" / commit "init" / ef814bf
import torch
from torch.utils.data import TensorDataset, DataLoader
import numpy as np
from models import SingleTransformer, MultiModalTransformer
import config
from data import create_dataset
def create_masked_input(input_tensor, mask_token, mask_prob=0.20):
    """
    Creates a masked input tensor by randomly replacing elements with a mask token.

    Each element is selected independently with probability ``mask_prob``. The
    input tensor itself is left untouched; a clone is modified and returned.

    Args:
        input_tensor (torch.Tensor): The input tensor to be masked.
        mask_token: The token to be used for masking.
        mask_prob (float, optional): The probability of masking an element. Defaults to 0.20.
    Returns:
        torch.Tensor: The masked input tensor.
        torch.Tensor: A boolean mask (same shape and device as ``input_tensor``)
            indicating which elements were masked.
    """
    # Draw the random numbers on the input's own device so the returned mask can
    # be combined with the masked tensor without a device mismatch.
    # torch.rand (not rand_like) is used because the input may be an integer tensor.
    mask = torch.rand(input_tensor.shape, device=input_tensor.device) < mask_prob
    masked_input = input_tensor.clone()
    masked_input[mask] = mask_token
    return masked_input, mask
def get_max(adata):
    """
    Get the maximum value in the data.

    Args:
        adata (list): A list of AnnData objects (anything exposing a ``.X``
            matrix works). ``.X`` may be sparse (has ``toarray()``) or dense.
    Returns:
        The maximum value over all ``.X`` matrices, as a NumPy scalar.
    Raises:
        TypeError: If ``adata`` is not a list.
        ValueError: If ``adata`` is empty (raised by ``max``) or a matrix has
            no elements (raised by NumPy).
    """
    # Explicit check instead of `assert`, which is stripped under `python -O`.
    if not isinstance(adata, list):
        raise TypeError("adata must be a list of AnnData objects.")
    per_object_max = []
    for ad in adata:
        X = ad.X
        # Sparse matrices are densified; dense arrays are used as they are.
        dense = X.toarray() if hasattr(X, "toarray") else np.asarray(X)
        per_object_max.append(dense.max())
    return max(per_object_max)
def get_token_embeddings(model, dataset, device, batch_size=32):
    """
    Get the token embeddings for the dataset.

    The model is switched to eval mode and called batch by batch (in dataset
    order, without gradients) as ``model(inputs, bi, return_embeddings=True)``.
    The per-batch outputs are moved to the CPU and concatenated along dim 0.

    Args:
        model (torch.nn.Module): Model.
        dataset (torch.utils.data.Dataset): Dataset yielding ``(inputs, batch_index)``
            or ``(inputs, batch_index, label)``; ``inputs`` is either a single
            tensor or a sequence of per-modality tensors.
        device (str): Device to use.
        batch_size (int, optional): Number of samples per forward pass. Defaults to 32.
    Returns:
        torch.Tensor: Embeddings.
    Raises:
        ValueError: If a batch does not contain 2 or 3 elements.
    """
    model.eval()
    embeddings = []
    loader = DataLoader(dataset, batch_size=batch_size, shuffle=False)
    with torch.no_grad():
        for batch in loader:
            if len(batch) == 3:
                inputs, bi, _ = batch  # the label is not needed here
            elif len(batch) == 2:
                inputs, bi = batch
            else:
                raise ValueError(
                    "Expected batches of (inputs, batch_index[, label]), "
                    f"got {len(batch)} elements."
                )
            if isinstance(inputs, (list, tuple)):
                # Multi-modal sample: move every modality tensor, hand the model a tuple.
                inputs = tuple(modality.to(device) for modality in inputs)
            else:
                inputs = inputs.to(device)
            bi = bi.to(device)
            output = model(inputs, bi, return_embeddings=True)
            embeddings.append(output.detach().cpu())
    # Concatenate embeddings across batches
    return torch.cat(embeddings, dim=0)  # presumably (n_samples, seq_len, d_model) - depends on the model
def get_all_modalities_available_samples(dataset):
    """
    Build a dataset restricted to samples that have data in every modality.

    A sample is kept only when its RNA row, its ATAC row and its Flux row each
    contain at least one non-zero entry (presumably an all-zero row marks a
    missing modality - confirm against the dataset construction).

    Args:
        dataset: Object exposing ``rna_data``, ``atac_data``, ``flux_data``,
            ``batch_no`` and ``labels``, all indexable by a boolean row mask.
    Returns:
        create_dataset.MultiModalDataset: New dataset built from the retained rows.
    """
    modalities = (dataset.rna_data, dataset.atac_data, dataset.flux_data)
    # One boolean per sample and modality: does the row hold any non-zero value?
    has_rna, has_atac, has_flux = [(m != 0).any(axis=1) for m in modalities]
    keep = has_rna & has_atac & has_flux
    filtered = tuple(m[keep] for m in modalities)
    return create_dataset.MultiModalDataset(
        filtered,
        dataset.batch_no[keep],
        dataset.labels[keep],
    )
def separate_dataset(ds):
    """
    Split a dataset into its label-0 part and its label-1 part.

    Args:
        ds (TensorDataset): Dataset holding exactly three tensors
            ``(X, b, y)``; ``y`` carries the labels.
    Returns:
        TensorDataset: Samples whose label equals 0.
        TensorDataset: Samples whose label equals 1.
    """
    features, batch_ids, labels = ds.tensors

    def _subset(label_value):
        # Select the rows of all three tensors that carry the requested label.
        keep = labels == label_value
        return TensorDataset(features[keep], batch_ids[keep], labels[keep])

    return _subset(0), _subset(1)
def create_multimodal_model(model_config, device, use_mlm=False):
    """
    Create a multimodal model.

    Builds one SingleTransformer per modality (RNA, ATAC, Flux), optionally
    initialises them from the MLM checkpoints named in ``config``, and wraps
    them in a MultiModalTransformer.

    Args:
        model_config (dict): Model configuration with the keys 'RNA', 'ATAC',
            'Flux' (per-modality keyword arguments), 'Share' (keyword arguments
            passed to every SingleTransformer) and 'Multi' (keyword arguments
            for the MultiModalTransformer).
        device (str): Device to use.
        use_mlm (bool, optional): Whether to use MLM pretraining. Defaults to False.
    Returns:
        MultiModalTransformer: Multimodal model.
    """
    share_config = model_config['Share']
    model_config_multi = model_config['Multi']

    rna_model = SingleTransformer("RNA", **model_config['RNA'], **share_config).to(device)
    atac_model = SingleTransformer("ATAC", **model_config['ATAC'], **share_config).to(device)
    flux_model = SingleTransformer("Flux", **model_config['Flux'], **share_config).to(device)

    if use_mlm:
        # NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
        # map_location lets checkpoints saved on another device (e.g. a GPU) be
        # loaded on the target device. strict=False tolerates state-dict keys
        # that do not match between checkpoint and model.
        checkpoints = (
            (rna_model, config.MLM_RNA_CKP),
            (atac_model, config.MLM_ATAC_CKP),
            (flux_model, config.MLM_FLUX_CKP),
        )
        for sub_model, ckp_path in checkpoints:
            state = torch.load(ckp_path, map_location=device)
            sub_model.load_state_dict(state, strict=False)

    model = MultiModalTransformer(rna_model, atac_model, flux_model, **model_config_multi).to(device)
    return model