Spaces:
Sleeping
Sleeping
| import numpy as np | |
| import torch | |
| from torch.utils.data import DataLoader, Subset | |
| from sklearn.metrics import confusion_matrix | |
| from models import SingleTransformer | |
| from utils.helpers import create_multimodal_model | |
def compute_confusion_matrices(id, model_config, fold_results, dataset, device):
    """
    Compute per-fold confusion matrices and an aggregated confusion matrix.

    For each cross-validation fold, reloads the fold's best checkpoint,
    runs inference on that fold's validation subset, thresholds the
    predicted scores at 0.5, and builds a 2x2 confusion matrix.

    Args:
        id (str): Model ID; one of 'RNA', 'ATAC', 'Flux', 'Multi'.
        model_config (dict): Keyword configuration for the model constructor.
        fold_results (list): One dict per fold, each containing
            'best_model_path' (checkpoint file) and 'val_idx'
            (validation indices into ``dataset``).
        dataset (torch.utils.data.Dataset): Full dataset; each fold's
            validation subset is selected via ``fold['val_idx']``.
        device (str): Device to run inference on (e.g. 'cpu', 'cuda').

    Returns:
        list: Per-fold confusion matrices, followed by the aggregated
            (element-wise summed) confusion matrix as the last entry.

    Raises:
        ValueError: If ``id`` is not a recognized model ID.
    """
    if id not in ['RNA', 'ATAC', 'Flux', 'Multi']:
        raise ValueError("id must be one of 'RNA', 'ATAC', 'Flux', 'Multi'")
    # Running 2x2 sum across folds.
    agg_cm = np.zeros((2, 2), dtype=int)
    cms = []
    for fold in fold_results:
        # map_location lets GPU-saved checkpoints load on CPU-only hosts.
        state_dict = torch.load(fold['best_model_path'], map_location=device)
        val_subset = Subset(dataset, fold['val_idx'])
        cls_valid_loader = DataLoader(val_subset, batch_size=32, shuffle=False)
        if id == 'Multi':
            model = create_multimodal_model(model_config, device, use_mlm=False)
        else:
            model = SingleTransformer(id, **model_config).to(device)
        model.load_state_dict(state_dict, strict=True)
        model.eval()
        val_preds, val_labels = [], []
        with torch.no_grad():
            for inputs, bi, y in cls_valid_loader:
                if isinstance(inputs, list):
                    # Multimodal batch: (RNA, ATAC, Flux) tensors.
                    rna = inputs[0].to(device)
                    atac = inputs[1].to(device)
                    flux = inputs[2].to(device)
                    inputs = (rna, atac, flux)
                else:
                    inputs = inputs.to(device)
                bi, y = bi.to(device), y.to(device)
                preds, _ = model(inputs, bi)
                val_preds.append(preds.cpu().numpy())
                val_labels.append(y.cpu().numpy())
        val_preds = np.concatenate(val_preds).ravel()
        val_labels = np.concatenate(val_labels).ravel()
        # Threshold scores at 0.5 for binary classification.
        binary_preds = (val_preds >= 0.5).astype(int)
        # labels=[0, 1] guarantees a 2x2 matrix even when a fold's
        # validation set contains only one class; without it the matrix
        # could be 1x1 and agg_cm += cm (and later ravel-to-4) would break.
        cm = confusion_matrix(val_labels, binary_preds, labels=[0, 1])
        agg_cm += cm
        cms.append(cm)
    cms.append(agg_cm)
    return cms
def compute_metrics_from_confusion_matrix(cm):
    """
    Compute per-fold classification metrics and summarize across folds.

    Args:
        cm (list): Per-fold 2x2 confusion matrices (layout
            [[tn, fp], [fn, tp]]) with the aggregated matrix appended as
            the last element; the aggregate is excluded from the stats.

    Returns:
        dict: {'average': {...}, 'std': {...}} where each inner dict maps
            'precision', 'recall', 'f1', 'accuracy' to the mean /
            standard deviation of that metric over the folds.
    """
    # Accumulate each metric across folds in its own list.
    per_metric = {'precision': [], 'recall': [], 'f1': [], 'accuracy': []}
    for fold_cm in cm[:-1]:  # last entry is the aggregated matrix — skip it
        tn, fp, fn, tp = fold_cm.ravel()
        # Guard every ratio against an empty denominator.
        prec = tp / (tp + fp) if (tp + fp) > 0 else 0
        rec = tp / (tp + fn) if (tp + fn) > 0 else 0
        f1 = 2 * (prec * rec) / (prec + rec) if (prec + rec) > 0 else 0
        total = tp + tn + fp + fn
        acc = (tp + tn) / total if total > 0 else 0
        per_metric['precision'].append(prec)
        per_metric['recall'].append(rec)
        per_metric['f1'].append(f1)
        per_metric['accuracy'].append(acc)
    return {
        'average': {name: np.mean(vals) for name, vals in per_metric.items()},
        'std': {name: np.std(vals) for name, vals in per_metric.items()},
    }