| import os |
| import json |
| import tqdm |
| import torch |
| import numpy as np |
| import click |
| from datetime import datetime |
| import lightning.pytorch as pl |
| import sklearn.metrics as skm |
|
|
| from torch.utils.data import DataLoader |
| from torchvision.transforms import transforms as T |
| from torchvision.transforms._transforms_video import ToTensorVideo |
| from pytorchvideo.transforms import Normalize |
|
|
| |
| from full_model.rnn_dataset import SyntaxDataset |
| from full_model.rnn_model import SyntaxLightningModule |
| from metrics_visualization import visualize_final_syntax_plotly_multi |
|
|
# Use the first CUDA device when available, otherwise fall back to CPU.
DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(f"DEVICE: {DEVICE}")
|
|
|
|
def safe_sample_std(values):
    """Sample standard deviation (ddof=1); returns 0.0 for empty or single-element input."""
    data = np.asarray(values, dtype=float)
    if data.size > 1:
        return float(data.std(ddof=1))
    return 0.0
|
|
|
|
def compute_metrics(y_true, y_pred, thr=22.0):
    """Compute regression and classification-style metrics for SYNTAX predictions.

    Args:
        y_true: ground-truth scores (array-like of floats).
        y_pred: predicted scores (array-like of floats).
        thr: threshold used to binarize scores for the recall metric.

    Returns:
        Tuple (r2, mae, pearson, mape, mean_recall). Pearson is 0.0 for a
        single sample; mean_recall is 0.0 when only one binarized class
        appears across true and predicted values.
    """
    true_arr = np.asarray(y_true, dtype=float)
    pred_arr = np.asarray(y_pred, dtype=float)

    r2 = float(skm.r2_score(true_arr, pred_arr))
    mae = float(skm.mean_absolute_error(true_arr, pred_arr))

    # Pearson correlation is undefined for fewer than two samples.
    if len(true_arr) > 1:
        pearson = float(np.corrcoef(true_arr, pred_arr)[0, 1])
    else:
        pearson = 0.0
    mape = float(skm.mean_absolute_percentage_error(true_arr, pred_arr))

    # Binarize at the threshold and average recall over both classes.
    true_bin = (true_arr >= thr).astype(int)
    pred_bin = (pred_arr >= thr).astype(int)
    seen_classes = np.unique(np.concatenate([true_bin, pred_bin]))
    if len(seen_classes) > 1:
        per_class = skm.recall_score(true_bin, pred_bin, average=None, labels=[0, 1])
        mean_recall = float(np.mean(per_class))
    else:
        mean_recall = 0.0

    return r2, mae, pearson, mape, mean_recall
|
|
|
|
def _scale_fold_preds(preds_all, scaling_params):
    """Apply per-fold linear calibration a*x + b to nested fold predictions.

    Args:
        preds_all: list of per-sample lists, one prediction per fold.
        scaling_params: mapping "fold{i}" -> (a, b); missing folds fall back
            to the identity transform (1.0, 0.0).

    Returns:
        New nested list with every prediction scaled.
    """
    scaled_all = []
    for pred_list in preds_all:
        scaled = []
        for i, val in enumerate(pred_list):
            # Look up the fold coefficients once per value (not twice).
            a, b = scaling_params.get(f"fold{i}", (1.0, 0.0))
            scaled.append(a * val + b)
        scaled_all.append(scaled)
    return scaled_all


@click.command()
@click.option("-d", "--dataset-paths", multiple=True,
              help="JSON с метаданными датасетов (относительно dataset_root).")
@click.option("-n", "--dataset-names", multiple=True,
              help="Имена датасетов для метрик/графиков.")
@click.option("-p", "--postfixes", multiple=True,
              help="Суффиксы для файлов предсказаний.")
@click.option("-r", "--dataset-root", type=click.Path(exists=True),
              help="Корень датасета (где лежат JSON и DICOM).")
@click.option("-v", "--video-size", type=click.Tuple([int, int]),
              help="Размер видео (H, W).")
@click.option("--frames-per-clip", type=int,
              help="Количество кадров в клипе.")
@click.option("--num-workers", type=int,
              help="Число DataLoader workers.")
@click.option("--seed", type=int,
              help="Random seed.")
@click.option("--pt-weights-format", is_flag=True,
              help="True → модели в .pt (torch.save), False → .ckpt (Lightning).")
@click.option("--use-scaling", is_flag=True,
              help="Применить a*x+b scaling из JSON.")
@click.option("--scaling-file",
              help="JSON с коэффициентами scaling (относительно dataset_root).")
@click.option("-e", "--ensemble-name",
              help="Имя ансамбля в metrics.json.")
@click.option("-m", "--metrics-file",
              help="JSON с метриками экспериментов.")
def main(dataset_paths, dataset_names, postfixes, dataset_root, video_size,
         frames_per_clip, num_workers, seed, pt_weights_format, use_scaling,
         scaling_file, ensemble_name, metrics_file):
    """Evaluate the 5-fold left/right ensemble on each dataset.

    For every (dataset_path, dataset_name, postfix) triple: run (or load
    cached) per-fold predictions for both arteries, optionally apply linear
    scaling, compute ensemble and per-fold metrics, append them to
    *metrics_file* under *ensemble_name*, and render a combined plotly figure.
    """
    pl.seed_everything(seed)
    postfix_plotly = "Ensemble"

    # Fixed fold checkpoints per artery; each sample gets one prediction per fold.
    model_paths = {
        "left": [
            "full_model/checkpoints/leftBinSyntax_R3D_fold00_lstm_mean_post_best.pt",
            "full_model/checkpoints/leftBinSyntax_R3D_fold01_lstm_mean_post_best.pt",
            "full_model/checkpoints/leftBinSyntax_R3D_fold02_lstm_mean_post_best.pt",
            "full_model/checkpoints/leftBinSyntax_R3D_fold03_lstm_mean_post_best.pt",
            "full_model/checkpoints/leftBinSyntax_R3D_fold04_lstm_mean_post_best.pt",
        ],
        "right": [
            "full_model/checkpoints/rightBinSyntax_R3D_fold00_lstm_mean_post_best.pt",
            "full_model/checkpoints/rightBinSyntax_R3D_fold01_lstm_mean_post_best.pt",
            "full_model/checkpoints/rightBinSyntax_R3D_fold02_lstm_mean_post_best.pt",
            "full_model/checkpoints/rightBinSyntax_R3D_fold03_lstm_mean_post_best.pt",
            "full_model/checkpoints/rightBinSyntax_R3D_fold04_lstm_mean_post_best.pt",
        ]
    }

    # Optional per-fold linear calibration (a*x + b) loaded from JSON.
    scaling_params_dict = {}
    if use_scaling:
        postfix_plotly += "_scaled"
        ensemble_name += "_scaled"
        scaling_path = os.path.join(dataset_root, scaling_file)
        if os.path.exists(scaling_path):
            with open(scaling_path, "r") as f:
                scaling_params_dict = json.load(f)
            print(f"Loaded scaling from {scaling_path}")
        else:
            print(f"⚠️ Scaling file not found: {scaling_path}")

    ensemble_results = {
        "timestamp": datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
        "use_scaling": use_scaling,
        "pt_weights_format": pt_weights_format,
        "datasets": {}
    }

    all_datasets, all_r2, all_recalls = {}, {}, {}

    for dataset_path, dataset_name, postfix in zip(dataset_paths, dataset_names, postfixes):
        abs_dataset_path = os.path.join(dataset_root, dataset_path)
        results_file = os.path.join(dataset_root, "coeffs", f"{postfix}.json")

        # Reuse cached per-fold predictions when present to skip inference.
        if os.path.exists(results_file):
            print(f"[{postfix}] Loading from {results_file}")
            with open(results_file, "r") as f:
                data = json.load(f)
            syntax_true = data["syntax_true"]
            left_preds_all = data["left_preds"]
            right_preds_all = data["right_preds"]
        else:
            print(f"[{postfix}] Computing predictions...")
            left_preds_all, left_sids = run_artery(
                abs_dataset_path, "left", model_paths["left"],
                video_size, frames_per_clip, num_workers, pt_weights_format
            )
            right_preds_all, right_sids = run_artery(
                abs_dataset_path, "right", model_paths["right"],
                video_size, frames_per_clip, num_workers, pt_weights_format
            )
            # Both arteries must be evaluated on the same studies in the same
            # order; `assert` would be stripped under `python -O`.
            if left_sids != right_sids:
                raise RuntimeError(
                    f"[{postfix}] left/right study id mismatch between arteries"
                )

            with open(abs_dataset_path, "r") as f:
                dataset = json.load(f)
            # Prefer the averaged label when present, fall back to "syntax".
            syntax_true = [rec.get("mean_syntax", rec.get("syntax")) for rec in dataset]

            os.makedirs(os.path.dirname(results_file), exist_ok=True)
            save_data = {
                "syntax_true": syntax_true,
                "left_preds": left_preds_all,
                "right_preds": right_preds_all
            }
            with open(results_file, "w") as f:
                json.dump(save_data, f)
            print(f"[{postfix}] Saved to {results_file}")

        if use_scaling:
            left_scaled_all = _scale_fold_preds(left_preds_all, scaling_params_dict)
            right_scaled_all = _scale_fold_preds(right_preds_all, scaling_params_dict)
        else:
            left_scaled_all, right_scaled_all = left_preds_all, right_preds_all

        # Ensemble prediction per sample: mean over folds of (left + right),
        # clamped at 0 since a SYNTAX score cannot be negative.
        syntax_pred = [max(0.0, float(np.mean([l + r for l, r in zip(l_list, r_list)])))
                       for l_list, r_list in zip(left_scaled_all, right_scaled_all)]

        r2, mae, pearson, mape, mean_recall = compute_metrics(syntax_true, syntax_pred)
        print(f"[{postfix}] ENSEMBLE: R2={r2:.4f}, Pearson={pearson:.4f}, "
              f"MAE={mae:.4f}, MAPE={mape:.4f}, Recall={mean_recall:.4f}")

        # Per-fold metrics: evaluate each fold's (left + right) on its own.
        n_folds = len(left_scaled_all[0]) if left_scaled_all else 0
        fold_metrics = {metric: [] for metric in ["R2", "MAE", "Pearson", "MAPE", "Mean_Recall"]}
        for k in range(n_folds):
            pred_k = [max(0.0, l_list[k] + r_list[k])
                      for l_list, r_list in zip(left_scaled_all, right_scaled_all)]
            fold_r2, fold_mae, fold_pearson, fold_mape, fold_recall = compute_metrics(syntax_true, pred_k)
            for metric, value in zip(fold_metrics.keys(),
                                     [fold_r2, fold_mae, fold_pearson, fold_mape, fold_recall]):
                fold_metrics[metric].append(value)

        fold_summary = {k: {"mean": float(np.mean(v)), "std": safe_sample_std(v), "values": v}
                        for k, v in fold_metrics.items()}

        all_datasets[dataset_name] = (syntax_true, syntax_pred)
        all_r2[dataset_name] = r2
        all_recalls[dataset_name] = mean_recall

        ensemble_results["datasets"][dataset_name] = {
            "R2": round(r2, 4), "MAE": round(mae, 4),
            "Pearson": round(pearson, 4), "MAPE": round(mape, 4),
            "Mean_Recall": round(mean_recall, 4), "N_samples": len(syntax_true),
            **{f"{k}_mean": round(v["mean"], 4) for k, v in fold_summary.items()},
            **{f"{k}_std": round(v["std"], 4) for k, v in fold_summary.items()},
            **{f"{k}_folds": [round(x, 4) for x in v["values"]] for k, v in fold_summary.items()}
        }

    # Merge this run into the metrics history file (keyed by ensemble name).
    metrics_path = os.path.join(dataset_root, metrics_file)
    full_history = {}
    if os.path.exists(metrics_path):
        try:
            with open(metrics_path, "r") as f:
                full_history = json.load(f)
        except json.JSONDecodeError:
            # Best-effort recovery: start a fresh history instead of crashing.
            print("⚠️ Metrics file corrupted. Creating new.")

    full_history[ensemble_name] = ensemble_results
    with open(metrics_path, "w") as f:
        json.dump(full_history, f, indent=4)
    print(f"✅ Metrics saved: {metrics_path}")

    visualize_final_syntax_plotly_multi(
        datasets=all_datasets, r2_values=all_r2, recall_values=all_recalls,
        gt_row="ENSEMBLE", postfix=postfix_plotly
    )
|
|
|
|
def run_artery(dataset_path, artery, model_paths, video_size, frames_per_clip,
               num_workers, pt_weights_format=False):
    """Run inference for one artery across all fold checkpoints (5 folds).

    Args:
        dataset_path: path to the dataset metadata JSON.
        artery: artery selector forwarded to SyntaxDataset (e.g. "left"/"right").
        model_paths: checkpoint paths, one per fold; missing files are skipped
            with a warning.
        video_size: target (H, W) for the resize transform.
        frames_per_clip: number of frames sampled per clip.
        num_workers: DataLoader worker count.
        pt_weights_format: True -> weights saved via torch.save (.pt),
            False -> Lightning checkpoint (.ckpt).

    Returns:
        (preds_all, sids): per-sample lists with one prediction per loaded
        model, and the study ids in loader order.

    Raises:
        RuntimeError: if none of the checkpoints could be loaded.
    """
    # ImageNet statistics — assumes the backbone was pretrained with ImageNet
    # normalization (TODO confirm against training transforms).
    imagenet_mean = [0.485, 0.456, 0.406]
    imagenet_std = [0.229, 0.224, 0.225]
    test_transform = T.Compose([
        ToTensorVideo(),
        T.Resize(size=video_size, antialias=True),
        Normalize(mean=imagenet_mean, std=imagenet_std),
    ])

    val_set = SyntaxDataset(
        root=os.path.dirname(dataset_path),
        meta=dataset_path,
        train=False,
        length=frames_per_clip,
        label="",
        artery=artery,
        inference=True,
        transform=test_transform
    )
    # batch_size=1 so each iteration yields exactly one study.
    val_loader = DataLoader(val_set, batch_size=1, num_workers=num_workers,
                            shuffle=False, pin_memory=True)
    print(f"{artery} artery: {len(val_loader)} samples")

    # Load every available fold checkpoint; skip missing files.
    models = []
    for path in model_paths:
        if not os.path.exists(path):
            print(f"⚠️ Model not found: {path}")
            continue
        model = SyntaxLightningModule(
            num_classes=2, lr=1e-5, variant="lstm_mean",
            weight_decay=0.001, max_epochs=1,
            pl_weight_path=path, pt_weights_format=pt_weights_format
        )
        model.to(DEVICE)
        model.eval()
        models.append(model)
    if not models:
        raise RuntimeError(f"No models loaded for {artery}")

    preds_all, sids = [], []
    with torch.no_grad():
        # NOTE(review): loader appears to yield (clip, [label], [t], [study_id])
        # with single-element lists per batch item — verify against SyntaxDataset.
        for x, [y], [t], [sid] in tqdm.tqdm(val_loader, desc=f"{artery} infer"):
            if len(x.shape) == 1:
                # 1-D tensor looks like a placeholder for a study with no usable
                # video (TODO confirm); predict 0.0 for every fold.
                val_syntax_list = [0.0] * len(models)
            else:
                x = x.to(DEVICE)
                val_syntax_list = []
                for model in models:
                    pred = model(x)
                    # Model returns a pair; second element is used as the value head.
                    _, val_log = pred
                    # exp(x) - 1 inversion — presumably the model was trained on
                    # log1p-transformed SYNTAX targets (confirm in training code).
                    val = float(torch.exp(val_log).cpu()) - 1
                    val_syntax_list.append(val)
            preds_all.append(val_syntax_list)
            sids.append(sid[0])

    return preds_all, sids
|
|
|
|
# Script entry point: dispatch to the click command.
if __name__ == "__main__":
    main()
|
|