import random

import numpy as np
import pandas as pd
import torch
from anndata import AnnData
from torch.utils.data import DataLoader, TensorDataset, random_split
from torch.utils.data.dataset import Dataset

# Binary targets for the two fate labels accepted by ``get_cls_dataset``.
_FATE_TO_TARGET = {'reprogramming': 1, 'dead-end': 0}


def _to_dense(matrix):
    """Return ``matrix`` as a dense NumPy array.

    Sparse matrices (anything exposing ``toarray``) are densified; dense
    array-likes go through ``np.asarray`` so AnnData objects with a dense
    ``X`` are accepted too.
    """
    if hasattr(matrix, 'toarray'):
        return matrix.toarray()
    return np.asarray(matrix)


def _row_to_1d(matrix):
    """Flatten a one-row (sparse or dense) matrix into a 1-D NumPy array."""
    return _to_dense(matrix).ravel()


def _frame_values(data):
    """Return ``data.values`` for a DataFrame, otherwise ``data`` unchanged."""
    return data.values if isinstance(data, pd.DataFrame) else data


def _is_frame_tuple(data):
    """Return True if ``data`` is a non-empty tuple whose first item is a DataFrame."""
    return isinstance(data, tuple) and len(data) > 0 and isinstance(data[0], pd.DataFrame)


def get_mlm_loaders(train_data, val_data, batch_size=32, batch_key='batch_no',
                    data_dtype=torch.float32):
    """Build train/validation loaders of ``(features, batch_index)`` pairs.

    Args:
        train_data: An ``AnnData`` whose ``obs[batch_key]`` holds integer batch
            indices, or a tuple ``(pd.DataFrame, batch_indices)``.
        val_data: Same form as ``train_data``; both must be AnnData or both
            must be tuples.
        batch_size: Mini-batch size used by both loaders.
        batch_key: ``obs`` column holding batch indices (AnnData inputs only).
        data_dtype: dtype of the feature tensors.

    Returns:
        ``(train_loader, val_loader)``; only the training loader shuffles.

    Raises:
        ValueError: If the inputs are neither both AnnData nor both
            DataFrame tuples.
    """
    if isinstance(train_data, AnnData) and isinstance(val_data, AnnData):
        splits = [(_to_dense(d.X), d.obs[batch_key]) for d in (train_data, val_data)]
    elif _is_frame_tuple(train_data) and _is_frame_tuple(val_data):
        splits = [(d[0].values, d[1]) for d in (train_data, val_data)]
    else:
        raise ValueError("Data must be an AnnData object or a tuple of (pd.DataFrame, list).")

    loaders = []
    for (features, batches), shuffle in zip(splits, (True, False)):
        dataset = TensorDataset(
            torch.tensor(features, dtype=data_dtype),
            # np.asarray reads a pandas Series by position, independent of its index.
            torch.tensor(np.asarray(batches), dtype=torch.int32),
        )
        loaders.append(DataLoader(dataset, batch_size=batch_size, shuffle=shuffle))
    mlm_train_loader, mlm_val_loader = loaders
    return mlm_train_loader, mlm_val_loader


def get_cls_dataset(data, batch_key='batch_no', label_key='label', pct_key='pct',
                    filter_pcts=50.0, data_dtype=torch.float32):
    """Build a fate-classification dataset, keeping samples with pct > ``filter_pcts``.

    Labels are mapped ``'reprogramming' -> 1.0`` and ``'dead-end' -> 0.0``;
    any other label raises ``KeyError``.

    Args:
        data: An ``AnnData`` with ``obs`` columns ``label_key``, ``batch_key`` and
            ``pct_key``, or a tuple ``(pd.DataFrame, labels, batch_indices, pcts)``.
        batch_key: ``obs`` column with batch indices (AnnData only).
        label_key: ``obs`` column with fate labels (AnnData only).
        pct_key: ``obs`` column with pct values (AnnData only).
        filter_pcts: A sample is kept only if its pct is strictly greater than
            this value (NaN pcts are therefore dropped).
        data_dtype: dtype of the feature tensor.

    Returns:
        ``(dataset, pcts, feature_names)``. ``dataset`` is a
        ``TensorDataset(X, batch, y)`` of the kept samples. ``pcts`` is the
        float32 tensor of their pct values. ``feature_names`` lists the
        feature columns.

    Raises:
        ValueError: If ``data`` is neither an AnnData nor a DataFrame tuple.
    """
    if isinstance(data, AnnData):
        features = _to_dense(data.X)
        raw_labels = data.obs[label_key]
        batches = data.obs[batch_key]
        pct_values = data.obs[pct_key]
        feature_names = data.var_names.tolist()
    elif _is_frame_tuple(data):
        features = data[0].values
        raw_labels, batches, pct_values = data[1], data[2], data[3]
        feature_names = data[0].columns.tolist()
    else:
        raise ValueError("Data must be an AnnData object or a tuple of (pd.DataFrame, list, list, list).")

    X = torch.tensor(features, dtype=data_dtype)
    y = torch.tensor([_FATE_TO_TARGET[label] for label in raw_labels], dtype=torch.float32)
    b = torch.tensor(np.asarray(batches), dtype=torch.int32)
    pcts = torch.tensor(np.asarray(pct_values), dtype=torch.float32)

    keep = pcts > filter_pcts  # computed once, applied to every tensor
    dataset = TensorDataset(X[keep], b[keep], y[keep])
    return dataset, pcts[keep], feature_names


def get_pair_modalities(adata_rna, adata_atac, flux_df, include_unused_atacs=False, seed=42):
    """Pair RNA, ATAC and Flux data based on clone IDs.

    Each RNA cell is paired with one not-yet-used ATAC cell from the same
    clone (``obs['clone_id']``), picked at random with ``seed``. Each ATAC cell
    is used at most once. An RNA cell with no available ATAC sibling gets an
    all-zero ATAC row. Flux rows are looked up with
    ``flux_df.loc[rna_index]``, so ``flux_df`` must be indexed by the RNA
    ``obs_names``. ``adata_atac`` is not modified.

    Args:
        adata_rna (AnnData): RNA data; ``obs`` needs ``clone_id``, ``label``,
            ``batch_no`` and ``pct`` columns.
        adata_atac (AnnData): ATAC data; ``obs`` needs ``clone_id`` (and
            ``label``, ``batch_no``, ``pct`` when ``include_unused_atacs``).
        flux_df (pd.DataFrame): Flux data indexed by RNA cell names.
        include_unused_atacs (bool): Also append ATAC cells that were not
            paired with any RNA cell, using all-zero RNA and flux rows.
        seed (int): Seed used when choosing among candidate ATAC siblings.

    Returns:
        tuple: ``(X_i, y_i, b_i, indices, pcts)``, where:

        - ``X_i``: ``(rna_data, atac_data, flux_data)``, three DataFrames with
          one row per paired sample. Their columns come from
          ``adata_rna.var_names``, ``adata_atac.var_names`` and
          ``flux_df.columns``. Rows are indexed by the
          ``(rna_index, atac_index)`` pairs.
        - ``y_i`` (np.ndarray): labels.
        - ``b_i`` (np.ndarray): batch indices.
        - ``indices`` (pd.DataFrame): ``RNA`` / ``ATAC`` columns with the matched
          cell names, ``None`` where a modality has no match.
        - ``pcts`` (np.ndarray): dominant fate percentage of each paired sample.
    """
    # Group ATAC cell names by clone, preserving obs order within each clone.
    # Built directly (no helper column) so the caller's obs is left untouched;
    # missing clone IDs are skipped, as ``groupby`` would skip them.
    atac_clone_to_indices = {}
    for atac_name, clone_id in zip(adata_atac.obs.index, adata_atac.obs['clone_id']):
        if pd.isna(clone_id):
            continue
        atac_clone_to_indices.setdefault(clone_id, []).append(atac_name)

    rna_data, atac_data, flux_data = [], [], []
    labels, batch_ind, indices, pcts = [], [], [], []
    used_atac_indices = set()
    # Private generator: reproducible picks without touching the global
    # ``random`` module state.
    rng = random.Random()

    for rna_index, row in adata_rna.obs.iterrows():
        siblings = [idx for idx in atac_clone_to_indices.get(row['clone_id'], [])
                    if idx not in used_atac_indices]
        rna_sample = _row_to_1d(adata_rna[rna_index].X)
        if siblings:
            # Reseeding before every pick reproduces the historical
            # ``random.seed(seed); random.choice(siblings)`` pairing exactly;
            # the chosen position therefore depends only on len(siblings).
            rng.seed(seed)
            atac_index = rng.choice(siblings)
            used_atac_indices.add(atac_index)
            atac_sample = _row_to_1d(adata_atac[atac_index].X)
        else:
            atac_index = None
            atac_sample = np.zeros(adata_atac.shape[1])  # no ATAC sibling available

        rna_data.append(rna_sample)
        atac_data.append(atac_sample)
        flux_data.append(flux_df.loc[rna_index].values)
        labels.append(row['label'])
        batch_ind.append(row['batch_no'])
        pcts.append(row['pct'])
        indices.append((rna_index, atac_index))

    if include_unused_atacs:
        unused_atac_indices = sorted(set(adata_atac.obs.index) - used_atac_indices)
        for atac_index in unused_atac_indices:
            rna_data.append(np.zeros(adata_rna.shape[1]))  # no RNA partner
            atac_data.append(_row_to_1d(adata_atac[atac_index].X))
            flux_data.append(np.zeros(flux_df.shape[1]))  # flux is keyed by RNA cells
            labels.append(adata_atac.obs.loc[atac_index, 'label'])
            batch_ind.append(adata_atac.obs.loc[atac_index, 'batch_no'])
            pcts.append(adata_atac.obs.loc[atac_index, 'pct'])
            indices.append((None, atac_index))

    rna_data = pd.DataFrame(rna_data, columns=adata_rna.var_names, index=indices)
    atac_data = pd.DataFrame(atac_data, columns=adata_atac.var_names, index=indices)
    flux_data = pd.DataFrame(flux_data, columns=flux_df.columns, index=indices)
    X_i = (rna_data, atac_data, flux_data)
    y_i = np.array(labels)
    b_i = np.array(batch_ind)
    # object dtype + reshape keeps a valid (0, 2) shape when nothing was paired.
    indices = pd.DataFrame(np.array(indices, dtype=object).reshape(-1, 2),
                           columns=["RNA", "ATAC"])
    pcts = np.array(pcts)
    return X_i, y_i, b_i, indices, pcts


class MultiModalDataset(Dataset):
    """Multi-modal dataset for RNA, ATAC, and Flux data.

    Each item is ``((rna, atac, flux), batch_no, label)`` for one sample.

    Args:
        X (tuple): ``(RNA, ATAC, Flux)``. Each entry may be a ``pd.DataFrame``
            or anything ``torch.tensor`` accepts, with one row per sample.
        batch_no (list): Batch index per sample (stored as int32).
        labels (list): Numeric label per sample (stored as float32).
        df_indics: Optional metadata, returned by ``get_df_indices``.
        pcts: Optional metadata, returned by ``get_pcts``.
        label_names: Optional metadata, returned by ``get_label_names``.
    """

    def __init__(self, X, batch_no, labels, df_indics=None, pcts=None, label_names=None):
        # NOTE(review): RNA is stored as int32, so fractional values are
        # truncated; presumably the RNA encoder expects raw counts — confirm.
        self.rna_data = torch.tensor(_frame_values(X[0]), dtype=torch.int32)
        self.atac_data = torch.tensor(_frame_values(X[1]), dtype=torch.float32)
        self.flux_data = torch.tensor(_frame_values(X[2]), dtype=torch.float32)
        self.batch_no = torch.tensor(batch_no, dtype=torch.int32)
        self.labels = torch.tensor(labels, dtype=torch.float32)
        self.df_indics = df_indics
        self.pcts = pcts
        self.label_names = label_names

    def __len__(self):
        return len(self.labels)

    def get_df_indices(self):
        """Return the ``df_indics`` metadata passed at construction."""
        return self.df_indics

    def get_pcts(self):
        """Return the ``pcts`` metadata passed at construction."""
        return self.pcts

    def get_label_names(self):
        """Return the ``label_names`` metadata passed at construction."""
        return self.label_names

    def __getitem__(self, idx):
        rna_sample = self.rna_data[idx]
        atac_sample = self.atac_data[idx]
        flux_sample = self.flux_data[idx]
        batch_no = self.batch_no[idx]
        label = self.labels[idx]
        return (rna_sample, atac_sample, flux_sample), batch_no, label