Spaces:
Running
Running
"""
Validation results analysis.

This module provides functions that build DataFrames of sample-level
predictions, labels, and metadata from cross-validation results, plus
helpers to summarize and filter those DataFrames.
"""
| import pandas as pd | |
| import numpy as np | |
| import torch | |
| from torch.utils.data import DataLoader, Subset | |
| from utils.helpers import create_multimodal_model | |
| from models import SingleTransformer | |
def get_sample_predictions_dataframe(
    model_type,
    multimodal_dataset,
    fold_results,
    model_config,
    device='cpu',
    batch_size=32,
    adata_rna=None,
    adata_atac=None,
    threshold=0.5,
):
    """
    Create a DataFrame with out-of-fold predictions and metadata per sample.

    For every fold, the fold's best checkpoint is loaded and run over that
    fold's validation subset. One row is emitted per (fold, validation
    sample) pair.

    Parameters
    ----------
    model_type : str
        Type of model: 'Multi' builds the multimodal model; any other value
        (e.g. 'RNA', 'ATAC', 'Flux') is passed as ``id`` to SingleTransformer.
    multimodal_dataset : MultiModalDataset
        The multimodal dataset containing all samples. Each item is expected
        to collate into ``(x, b, y)`` batches.
    fold_results : list of dict
        Fold result dictionaries from cross-validation. Each must provide
        ``'best_model_path'`` and ``'val_idx'``.
    model_config : dict
        Model configuration dictionary.
    device : str, optional
        Device to run predictions on ('cpu', 'cuda', 'mps').
    batch_size : int, optional
        Batch size for predictions.
    adata_rna : AnnData, optional
        RNA AnnData object for additional metadata.
    adata_atac : AnnData, optional
        ATAC AnnData object for additional metadata (used only when the RNA
        lookup does not match).
    threshold : float, optional
        Classification threshold for binary predictions (default: 0.5).

    Returns
    -------
    pd.DataFrame
        Rows sorted by ``ind``, with columns:

        - ind: Sample index in the dataset
        - fold: Fold number (1-based)
        - label_numeric: Actual label as int
        - label: 'reprogramming' if the label is 1, else 'dead-end'
        - predicted_value: Raw model output for the sample (assumed to be a
          probability in [0, 1]; the model's final activation is not visible
          here)
        - predicted_class_numeric: 1 if predicted_value >= threshold, else 0
        - predicted_class: 'reprogramming' or 'dead-end'
        - correct: 1 if predicted class equals the label, else 0
        - abs_error: |label - predicted_value|
        - modality: Modalities with non-zero data (e.g. 'RAF', 'A', 'RF')
        - batch_no: Batch number
        - pct: Percentage metadata (only if the dataset has ``pcts``)
        - selected AnnData ``obs`` fields such as clone_size / clone_id
          (only if df_indices and an AnnData object are available)

        If no validation samples are found, an empty DataFrame with the base
        columns is returned.

    Raises
    ------
    ValueError
        If a fold yields a different number of predictions or labels than
        it has validation indices (rows would otherwise be misaligned).
    """
    all_predictions = []
    all_labels = []
    all_indices = []
    all_folds = []

    print(f"Processing {len(fold_results)} folds...")
    for fold_idx, fold in enumerate(fold_results):
        fold_no = fold_idx + 1
        val_idx = fold['val_idx']

        val_loader = DataLoader(
            Subset(multimodal_dataset, val_idx), batch_size=batch_size, shuffle=False
        )
        model = _load_fold_model(model_type, model_config, fold['best_model_path'], device)
        fold_preds, fold_labels = _predict_loader(model, val_loader, device)

        # zip() below would silently truncate on a length mismatch and pair
        # predictions with the wrong samples, so fail loudly instead.
        n_val = len(val_idx)
        if len(fold_preds) != n_val or len(fold_labels) != n_val:
            raise ValueError(
                f"Fold {fold_no}: model produced {len(fold_preds)} predictions and the "
                f"loader {len(fold_labels)} labels for {n_val} validation samples; "
                "expected exactly one of each per sample."
            )

        all_predictions.extend(fold_preds)
        all_labels.extend(fold_labels)
        all_indices.extend(val_idx)
        all_folds.extend([fold_no] * n_val)
        print(f" Fold {fold_no}: {n_val} samples processed")

    if not all_indices:
        # A frame built from zero rows has no columns, so sort_values('ind')
        # and the summary prints below would fail; return a typed empty frame.
        print("\nTotal samples: 0")
        return pd.DataFrame(columns=[
            'ind', 'fold', 'label_numeric', 'label', 'predicted_value',
            'predicted_class_numeric', 'predicted_class', 'correct',
            'abs_error', 'modality', 'batch_no',
        ])

    all_indices = np.array(all_indices)
    modalities = _get_modality_info(multimodal_dataset, all_indices)

    # NOTE(review): the dataset attribute is read as `df_indics` (sic); confirm
    # it matches the name defined by the dataset class.
    df_indices = getattr(multimodal_dataset, 'df_indics', None)
    pcts = getattr(multimodal_dataset, 'pcts', None)
    want_obs = df_indices is not None and (adata_rna is not None or adata_atac is not None)

    samples_data = []
    for idx, pred, label, sample_fold, modality in zip(
        all_indices, all_predictions, all_labels, all_folds, modalities
    ):
        pred_class = int(pred >= threshold)
        sample_info = {
            'ind': idx,
            'fold': sample_fold,
            'label_numeric': int(label),
            'label': 'reprogramming' if label == 1 else 'dead-end',
            'predicted_value': float(pred),
            'predicted_class_numeric': pred_class,
            'predicted_class': 'reprogramming' if pred_class == 1 else 'dead-end',
            'correct': int(pred_class == int(label)),
            'abs_error': float(abs(label - pred)),
            'modality': modality,
            'batch_no': int(multimodal_dataset.batch_no[idx].item()),
        }
        if pcts is not None:
            sample_info['pct'] = float(pcts[idx])
        if want_obs:
            obs = _lookup_obs_row(df_indices, idx, adata_rna, adata_atac)
            if obs is not None:
                _add_obs_metadata(sample_info, obs)
        samples_data.append(sample_info)

    df_samples = pd.DataFrame(samples_data)
    df_samples = df_samples.sort_values('ind').reset_index(drop=True)

    print(f"\nTotal samples: {len(df_samples)}")
    print(f"Correct predictions: {df_samples['correct'].sum()} ({100 * df_samples['correct'].mean():.2f}%)")
    print(f"Mean absolute error: {df_samples['abs_error'].mean():.4f}")
    return df_samples


def _load_fold_model(model_type, model_config, model_path, device):
    """Build the architecture for ``model_type``, load ``model_path`` weights, and set eval mode."""
    if model_type == 'Multi':
        model = create_multimodal_model(model_config, device, use_mlm=False)
    else:
        model = SingleTransformer(id=model_type, **model_config).to(device)
    # Load onto CPU first so checkpoints saved from another device still open.
    # NOTE: torch.load unpickles the file; only use checkpoints from trusted sources.
    state_dict = torch.load(model_path, map_location='cpu')
    model.load_state_dict(state_dict)
    model = model.to(device)
    model.eval()
    return model


def _predict_loader(model, loader, device):
    """Run ``model`` over ``loader`` without gradients; return flat lists of predictions and labels."""
    preds_out = []
    labels_out = []
    with torch.no_grad():
        for x, b, y in loader:
            # Multimodal batches collate into a sequence of per-modality tensors.
            if isinstance(x, (list, tuple)):
                x = tuple(t.to(device) for t in x)
            else:
                x = x.to(device)
            b = b.to(device)
            preds, _ = model(x, b)
            # Flatten to one value per sample. This matches squeeze()/unsqueeze(0)
            # for (batch,) and (batch, 1) shapes; any other shape is caught by
            # the caller's length check.
            preds_out.extend(preds.reshape(-1).cpu().numpy())
            labels_out.extend(y.reshape(-1).cpu().numpy())
    return preds_out, labels_out


def _lookup_obs_row(df_indices, idx, adata_rna, adata_atac):
    """Return the AnnData ``obs`` row for sample ``idx`` (RNA first, then ATAC), or None."""
    # Column 0 of df_indices is used as the RNA obs key and column 1 as the ATAC obs key.
    rna_id = df_indices.iloc[idx, 0] if df_indices.shape[1] > 0 else None
    atac_id = df_indices.iloc[idx, 1] if df_indices.shape[1] > 1 else None
    if adata_rna is not None and rna_id is not None and rna_id in adata_rna.obs.index:
        return adata_rna.obs.loc[rna_id]
    if adata_atac is not None and atac_id is not None and atac_id in adata_atac.obs.index:
        return adata_atac.obs.loc[atac_id]
    return None
| def _get_modality_info(dataset, indices): | |
| """ | |
| Determine which modalities are available for each sample. | |
| Returns a list of modality strings: | |
| - 'RAF': RNA, ATAC, Flux all available | |
| - 'RA': RNA and ATAC available | |
| - 'RF': RNA and Flux available | |
| - 'AF': ATAC and Flux available | |
| - 'R': Only RNA available | |
| - 'A': Only ATAC available | |
| - 'F': Only Flux available | |
| """ | |
| modalities = [] | |
| for idx in indices: | |
| # Check if each modality has data | |
| has_rna = (dataset.rna_data[idx] != 0).any().item() | |
| has_atac = (dataset.atac_data[idx] != 0).any().item() | |
| has_flux = (dataset.flux_data[idx] != 0).any().item() | |
| # Build modality string | |
| modality = '' | |
| if has_rna: | |
| modality += 'R' | |
| if has_atac: | |
| modality += 'A' | |
| if has_flux: | |
| modality += 'F' | |
| modalities.append(modality if modality else 'None') | |
| return modalities | |
| def _add_obs_metadata(sample_info, obs): | |
| """Add metadata from AnnData obs to sample_info dictionary.""" | |
| metadata_fields = [ | |
| 'clone_size', 'clone_id', 'cells_RNA', 'cells_ATAC', | |
| 'cells_RNA_D3', 'cells_ATAC_D3', 'n_genes', 'phase', | |
| 'G2M_score', 'pct_counts_mt', 'total_counts' | |
| ] | |
| for field in metadata_fields: | |
| if field in obs: | |
| value = obs[field] | |
| # Handle different data types | |
| if pd.notna(value): | |
| if isinstance(value, (int, float, np.integer, np.floating)): | |
| sample_info[field] = value | |
| else: | |
| sample_info[field] = str(value) | |
def summarize_by_modality(df_samples):
    """
    Aggregate prediction performance per modality string.

    Parameters
    ----------
    df_samples : pd.DataFrame
        Output of get_sample_predictions_dataframe; must contain the columns
        'modality', 'ind', 'correct', 'abs_error' and 'predicted_value'.

    Returns
    -------
    pd.DataFrame
        Columns: modality, n_samples, accuracy, mean_abs_error, mean_pred,
        std_pred. Statistics are rounded to 4 decimals, and rows are ordered
        by n_samples, largest first.
    """
    return (
        df_samples.groupby('modality')
        .agg(
            n_samples=('ind', 'count'),
            accuracy=('correct', 'mean'),
            mean_abs_error=('abs_error', 'mean'),
            mean_pred=('predicted_value', 'mean'),
            std_pred=('predicted_value', 'std'),
        )
        .round(4)
        .reset_index()
        .sort_values('n_samples', ascending=False)
    )
def summarize_by_fold(df_samples):
    """
    Aggregate prediction performance per cross-validation fold.

    Parameters
    ----------
    df_samples : pd.DataFrame
        Output of get_sample_predictions_dataframe; must contain the columns
        'fold', 'ind', 'correct', 'abs_error' and 'predicted_value'.

    Returns
    -------
    pd.DataFrame
        Columns: fold, n_samples, accuracy, mean_abs_error, mean_pred,
        std_pred, one row per fold (in groupby order). Statistics are rounded
        to 4 decimals.
    """
    per_fold = df_samples.groupby('fold').agg(
        n_samples=('ind', 'count'),
        accuracy=('correct', 'mean'),
        mean_abs_error=('abs_error', 'mean'),
        mean_pred=('predicted_value', 'mean'),
        std_pred=('predicted_value', 'std'),
    )
    return per_fold.round(4).reset_index()
def get_misclassified_samples(df_samples):
    """
    Return the rows whose prediction was wrong.

    Parameters
    ----------
    df_samples : pd.DataFrame
        Output of get_sample_predictions_dataframe (needs a 'correct' column).

    Returns
    -------
    pd.DataFrame
        An independent copy of the rows where ``correct == 0``, with the
        original index preserved.
    """
    is_wrong = df_samples['correct'] == 0
    return df_samples.loc[is_wrong].copy()
def get_samples_by_modality(df_samples, modality):
    """
    Return the rows recorded with a given modality string.

    Parameters
    ----------
    df_samples : pd.DataFrame
        Output of get_sample_predictions_dataframe (needs a 'modality' column).
    modality : str
        Exact modality string to match (e.g. 'RAF', 'A', 'RF').

    Returns
    -------
    pd.DataFrame
        An independent copy of the matching rows, with the original index
        preserved.
    """
    matches = df_samples['modality'] == modality
    return df_samples.loc[matches].copy()
if __name__ == "__main__":
    # Running the module directly only prints a pointer to the public entry point.
    for _line in (
        "This module provides functions to analyze validation results.",
        "Main function: get_sample_predictions_dataframe()",
    ):
        print(_line)