"""
Validation Results Analysis

This module provides functions to create comprehensive DataFrames containing
sample-level predictions, labels, and metadata from cross-validation results.
"""

import pandas as pd
import numpy as np
import torch
from torch.utils.data import DataLoader, Subset

from utils.helpers import create_multimodal_model
from models import SingleTransformer


# Columns every row produced by get_sample_predictions_dataframe carries, in
# output order; also used to give an empty result a stable schema.
_BASE_COLUMNS = [
    'ind', 'fold', 'label_numeric', 'label', 'predicted_value',
    'predicted_class_numeric', 'predicted_class', 'correct', 'abs_error',
    'modality', 'batch_no',
]

# AnnData ``obs`` columns copied into a sample row when present and non-null.
_OBS_METADATA_FIELDS = (
    'clone_size', 'clone_id', 'cells_RNA', 'cells_ATAC', 'cells_RNA_D3',
    'cells_ATAC_D3', 'n_genes', 'phase', 'G2M_score', 'pct_counts_mt',
    'total_counts',
)


def get_sample_predictions_dataframe(
    model_type,
    multimodal_dataset,
    fold_results,
    model_config,
    device='cpu',
    batch_size=32,
    adata_rna=None,
    adata_atac=None,
    threshold=0.5
):
    """
    Creates a comprehensive DataFrame with sample-level predictions and metadata.

    For every fold, the fold's best checkpoint is loaded and evaluated on that
    fold's validation indices. The per-sample results of all folds are combined
    into one DataFrame sorted by dataset index.

    Parameters
    ----------
    model_type : str
        Type of model: 'Multi', 'RNA', 'ATAC', or 'Flux'
    multimodal_dataset : MultiModalDataset
        The multimodal dataset containing all samples
    fold_results : list
        List of fold result dictionaries from cross-validation; each must
        provide 'best_model_path' and 'val_idx'
    model_config : dict
        Model configuration dictionary
    device : str, optional
        Device to run predictions on ('cpu', 'cuda', 'mps')
    batch_size : int, optional
        Batch size for predictions
    adata_rna : AnnData, optional
        RNA AnnData object for additional metadata
    adata_atac : AnnData, optional
        ATAC AnnData object for additional metadata
    threshold : float, optional
        Classification threshold for binary predictions (default: 0.5)

    Returns
    -------
    pd.DataFrame
        DataFrame with columns:
        - ind: Sample index in the dataset
        - fold: Fold number (1-based)
        - label_numeric: Actual label (0 or 1)
        - label: Actual label name ('dead-end' or 'reprogramming')
        - predicted_value: Predicted value as output by the model
        - predicted_class_numeric: Predicted class (0 or 1)
        - predicted_class: Predicted class name ('dead-end' or 'reprogramming')
        - correct: Whether prediction matches label (1/0)
        - abs_error: Absolute error of prediction
        - modality: Available modalities for this sample (e.g., 'RAF', 'A', 'RF')
        - batch_no: Batch number
        - pct: Percentage metadata (if available)
        - clone_size: Clone size (if available)
        - clone_id: Clone ID (if available)
        - (additional RNA/ATAC metadata if adata objects provided)
        If no validation samples are produced, an empty DataFrame with the
        base columns above (``ind`` .. ``batch_no``) is returned.

    Raises
    ------
    ValueError
        If a fold's model yields a different number of predictions or labels
        than the fold has validation samples.
    """
    # Collect all predictions across folds
    all_predictions = []
    all_labels = []
    all_indices = []
    all_folds = []

    print(f"Processing {len(fold_results)} folds...")

    for fold_idx, fold in enumerate(fold_results):
        model_path = fold['best_model_path']
        val_idx = fold['val_idx']

        fold_preds, fold_labels = _predict_fold(
            model_type, model_config, model_path,
            multimodal_dataset, val_idx, device, batch_size,
        )

        # Rows are later built by zipping predictions with indices; zip()
        # would silently truncate on a mismatch and pair predictions with the
        # wrong samples, so fail loudly instead.
        if len(fold_preds) != len(val_idx) or len(fold_labels) != len(val_idx):
            raise ValueError(
                f"Fold {fold_idx + 1}: got {len(fold_preds)} predictions and "
                f"{len(fold_labels)} labels for {len(val_idx)} validation samples"
            )

        all_predictions.extend(fold_preds)
        all_labels.extend(fold_labels)
        all_indices.extend(val_idx)
        all_folds.extend([fold_idx + 1] * len(val_idx))

        print(f" Fold {fold_idx + 1}: {len(val_idx)} samples processed")

    if not all_predictions:
        # Nothing to analyse: return a schema-stable empty frame instead of
        # crashing on sort_values('ind') over a column-less DataFrame.
        print("\nTotal samples: 0")
        return pd.DataFrame(columns=_BASE_COLUMNS)

    all_predictions = np.array(all_predictions)
    all_labels = np.array(all_labels)
    all_indices = np.array(all_indices)
    all_folds = np.array(all_folds)

    # Determine modality availability for each sample
    modalities = _get_modality_info(multimodal_dataset, all_indices)

    # Optional dataset attributes; absent attributes simply disable the
    # corresponding columns.
    # NOTE(review): the attribute is spelled 'df_indics' — confirm it matches
    # the dataset class.
    df_indices = getattr(multimodal_dataset, 'df_indics', None)
    pcts = getattr(multimodal_dataset, 'pcts', None)

    rows = [
        _build_sample_row(
            multimodal_dataset, idx, fold_no, label, pred, modality,
            pcts, df_indices, adata_rna, adata_atac, threshold,
        )
        for idx, pred, label, fold_no, modality in zip(
            all_indices, all_predictions, all_labels, all_folds, modalities
        )
    ]

    df_samples = pd.DataFrame(rows)

    # Sort by index for easier analysis; a stable sort keeps fold order for
    # any index that was evaluated in more than one fold.
    df_samples = df_samples.sort_values('ind', kind='stable').reset_index(drop=True)

    print(f"\nTotal samples: {len(df_samples)}")
    print(
        f"Correct predictions: {df_samples['correct'].sum()} "
        f"({100 * df_samples['correct'].mean():.2f}%)"
    )
    print(f"Mean absolute error: {df_samples['abs_error'].mean():.4f}")

    return df_samples


def _load_fold_model(model_type, model_config, model_path, device):
    """Build the model for ``model_type``, load ``model_path`` weights, and set eval mode."""
    if model_type == 'Multi':
        model = create_multimodal_model(model_config, device, use_mlm=False)
    else:
        model = SingleTransformer(id=model_type, **model_config).to(device)

    # Load onto CPU first so checkpoints saved on any device can be restored.
    state_dict = torch.load(model_path, map_location='cpu')
    model.load_state_dict(state_dict)
    model = model.to(device)
    model.eval()
    return model


def _inputs_to_device(x, device):
    """Move model inputs to ``device``; list/tuple inputs become an (rna, atac, flux) tuple."""
    if isinstance(x, (list, tuple)):
        # Multimodal batches: the first three entries are RNA, ATAC and Flux.
        return (x[0].to(device), x[1].to(device), x[2].to(device))
    return x.to(device)


def _predict_fold(model_type, model_config, model_path, dataset, val_idx,
                  device, batch_size):
    """
    Evaluate one fold's checkpoint on its validation subset.

    Returns two lists (predictions, labels) of scalar numpy values, in the
    same order as ``val_idx``. The model is local to this call, so it is
    released before the next fold's model is built.
    """
    model = _load_fold_model(model_type, model_config, model_path, device)

    # shuffle=False keeps predictions aligned with val_idx order.
    val_loader = DataLoader(Subset(dataset, val_idx), batch_size=batch_size, shuffle=False)

    fold_preds = []
    fold_labels = []
    with torch.no_grad():
        for x, b, y in val_loader:
            x = _inputs_to_device(x, device)
            b = b.to(device)

            preds, _ = model(x, b)

            # Flatten so (B,), (B, 1) and single-sample 0-d outputs all yield
            # one scalar per sample.
            fold_preds.extend(preds.reshape(-1).cpu().numpy())
            fold_labels.extend(y.reshape(-1).cpu().numpy())

    return fold_preds, fold_labels


def _class_name(class_numeric):
    """Map a numeric class to its name: 1 -> 'reprogramming', anything else -> 'dead-end'."""
    return 'reprogramming' if class_numeric == 1 else 'dead-end'


def _build_sample_row(dataset, idx, fold_no, label, pred, modality, pcts,
                      df_indices, adata_rna, adata_atac, threshold):
    """Build the output row (a dict) for a single validation sample."""
    pred_class = int(pred >= threshold)

    sample_info = {
        'ind': idx,
        'fold': fold_no,
        'label_numeric': int(label),
        'label': _class_name(label),
        'predicted_value': float(pred),
        'predicted_class_numeric': pred_class,
        'predicted_class': _class_name(pred_class),
        'correct': int(pred_class == int(label)),
        'abs_error': float(abs(label - pred)),
        'modality': modality,
        'batch_no': int(dataset.batch_no[idx].item()),
    }

    # Add percentage if available
    if pcts is not None:
        sample_info['pct'] = float(pcts[idx])

    # Add additional metadata from AnnData objects if available
    if df_indices is not None and (adata_rna is not None or adata_atac is not None):
        _attach_anndata_metadata(sample_info, idx, df_indices, adata_rna, adata_atac)

    return sample_info


def _lookup_obs(adata, obs_name):
    """Return the ``adata.obs`` row for ``obs_name``, or None when unavailable."""
    if adata is None or obs_name is None or obs_name not in adata.obs.index:
        return None
    obs = adata.obs.loc[obs_name]
    # Non-unique obs_names make .loc return a DataFrame; take the first match
    # so _add_obs_metadata always receives a single row (Series).
    if isinstance(obs, pd.DataFrame):
        obs = obs.iloc[0]
    return obs


def _attach_anndata_metadata(sample_info, idx, df_indices, adata_rna, adata_atac):
    """
    Copy AnnData obs metadata for sample ``idx`` into ``sample_info``.

    Column 0 of ``df_indices`` holds the RNA obs name and column 1 the ATAC
    obs name. RNA metadata is preferred; ATAC is used only when no RNA row
    is found.
    """
    n_id_cols = df_indices.shape[1]
    rna_id = df_indices.iloc[idx, 0] if n_id_cols > 0 else None
    atac_id = df_indices.iloc[idx, 1] if n_id_cols > 1 else None

    obs = _lookup_obs(adata_rna, rna_id)
    if obs is None:
        obs = _lookup_obs(adata_atac, atac_id)
    if obs is not None:
        _add_obs_metadata(sample_info, obs)


def _get_modality_info(dataset, indices):
    """
    Determine which modalities are available for each sample.

    A modality counts as available when its data tensor for the sample has
    at least one non-zero entry.

    Returns a list of modality strings:
    - 'RAF': RNA, ATAC, Flux all available
    - 'RA': RNA and ATAC available
    - 'RF': RNA and Flux available
    - 'AF': ATAC and Flux available
    - 'R': Only RNA available
    - 'A': Only ATAC available
    - 'F': Only Flux available
    - 'None': no modality has non-zero data
    """
    rna_data = dataset.rna_data
    atac_data = dataset.atac_data
    flux_data = dataset.flux_data

    modalities = []
    for idx in indices:
        has_rna = (rna_data[idx] != 0).any().item()
        has_atac = (atac_data[idx] != 0).any().item()
        has_flux = (flux_data[idx] != 0).any().item()

        modality = ''
        if has_rna:
            modality += 'R'
        if has_atac:
            modality += 'A'
        if has_flux:
            modality += 'F'

        modalities.append(modality if modality else 'None')

    return modalities


def _add_obs_metadata(sample_info, obs):
    """
    Add metadata from one AnnData obs row to the sample_info dictionary.

    Only fields listed in _OBS_METADATA_FIELDS that are present and non-null
    are copied. Numeric values are kept as-is; everything else (e.g.
    categoricals) is stored as ``str``.
    """
    for field in _OBS_METADATA_FIELDS:
        if field not in obs:
            continue
        value = obs[field]
        if pd.isna(value):
            continue
        if isinstance(value, (int, float, np.integer, np.floating)):
            sample_info[field] = value
        else:
            sample_info[field] = str(value)


def _summarize_by(df_samples, key):
    """
    Aggregate per-group sample count, accuracy, MAE and prediction mean/std.

    Returns a DataFrame with columns [key, 'n_samples', 'accuracy',
    'mean_abs_error', 'mean_pred', 'std_pred'], values rounded to 4 decimals.
    """
    summary = df_samples.groupby(key).agg(
        n_samples=('ind', 'count'),
        accuracy=('correct', 'mean'),
        mean_abs_error=('abs_error', 'mean'),
        mean_pred=('predicted_value', 'mean'),
        std_pred=('predicted_value', 'std'),
    ).round(4)
    return summary.reset_index()


def summarize_by_modality(df_samples):
    """
    Summarize prediction performance by modality.

    Parameters
    ----------
    df_samples : pd.DataFrame
        DataFrame from get_sample_predictions_dataframe

    Returns
    -------
    pd.DataFrame
        Summary statistics grouped by modality, largest groups first
    """
    summary = _summarize_by(df_samples, 'modality')
    summary = summary.sort_values('n_samples', ascending=False)
    return summary


def summarize_by_fold(df_samples):
    """
    Summarize prediction performance by fold.

    Parameters
    ----------
    df_samples : pd.DataFrame
        DataFrame from get_sample_predictions_dataframe

    Returns
    -------
    pd.DataFrame
        Summary statistics grouped by fold
    """
    return _summarize_by(df_samples, 'fold')


def get_misclassified_samples(df_samples):
    """
    Get only misclassified samples.

    Parameters
    ----------
    df_samples : pd.DataFrame
        DataFrame from get_sample_predictions_dataframe

    Returns
    -------
    pd.DataFrame
        DataFrame (a copy) containing only misclassified samples
    """
    return df_samples[df_samples['correct'] == 0].copy()


def get_samples_by_modality(df_samples, modality):
    """
    Get samples filtered by modality.

    Parameters
    ----------
    df_samples : pd.DataFrame
        DataFrame from get_sample_predictions_dataframe
    modality : str
        Modality string (e.g., 'RAF', 'A', 'RF')

    Returns
    -------
    pd.DataFrame
        Filtered DataFrame (a copy)
    """
    return df_samples[df_samples['modality'] == modality].copy()


if __name__ == "__main__":
    # Example usage
    print("This module provides functions to analyze validation results.")
    print("Main function: get_sample_predictions_dataframe()")