| import os |
| import json |
| import torch |
| import numpy as np |
| import click |
| import lightning.pytorch as pl |
| from lightning.pytorch.loggers import TensorBoardLogger |
| from lightning.pytorch.callbacks import ModelCheckpoint, LearningRateMonitor |
| from lightning.pytorch.profilers import AdvancedProfiler, PyTorchProfiler |
|
|
| from pytorchvideo.transforms import Normalize, Permute, RandAugment |
| from torch.utils.data import DataLoader, WeightedRandomSampler |
| from torchvision.transforms import transforms as T |
| from torchvision.transforms._transforms_video import ToTensorVideo |
| from torchvision.transforms import InterpolationMode |
|
|
| from dataset import SyntaxDataset |
| from pl_model import SyntaxLightningModule |
|
|
import warnings

# Silence a noisy DDP-related warning raised during init_process_group;
# presumably emitted once per rank by Lightning's distributed setup — TODO confirm.
warnings.filterwarnings("ignore", message="No device id is provided via `init_process_group`")


# Allow TF32/reduced-precision matmul kernels for better GPU throughput
# at a small cost in float32 matmul precision.
torch.set_float32_matmul_precision("medium")
|
|
|
|
| """ |
| Скрипт обучения backbone (3D-ResNet) для предсказания SYNTAX score. |
| |
| Шаги: |
| 1) предварительное обучение (pretrain) — обучается только последний слой; |
| 2) полное дообучение (full) — fine-tuning всего backbone. |
| """ |
|
|
|
|
| |
def get_transforms(video_size, imagenet_mean, imagenet_std, train=True):
    """Build the video preprocessing pipeline.

    Validation: deterministic bicubic resize followed by ImageNet
    normalization. Training: per-frame RandAugment and a random horizontal
    flip, then a resize whose interpolation mode is picked at random
    (bilinear or bicubic) per clip, then normalization.

    Args:
        video_size: target (H, W) of each frame.
        imagenet_mean: per-channel mean for Normalize.
        imagenet_std: per-channel std for Normalize.
        train: build the augmented training pipeline if True.

    Returns:
        A torchvision ``Compose`` callable operating on raw video tensors.
    """
    normalize = Normalize(mean=imagenet_mean, std=imagenet_std)

    if not train:
        return T.Compose([
            ToTensorVideo(),
            T.Resize(size=video_size, interpolation=InterpolationMode.BICUBIC, antialias=True),
            normalize,
        ])

    # One Resize per interpolation mode; RandomChoice picks uniformly at runtime.
    resize_variants = [
        T.Resize(size=video_size, interpolation=mode, antialias=True)
        for mode in (InterpolationMode.BILINEAR, InterpolationMode.BICUBIC)
    ]
    # Permute (C, T, H, W) -> (T, C, H, W) around RandAugment, which operates
    # per-frame, then permute back — assumes ToTensorVideo yields channels-first
    # video; TODO confirm against ToTensorVideo's output layout.
    return T.Compose([
        ToTensorVideo(),
        Permute(dims=[1, 0, 2, 3]),
        RandAugment(magnitude=10, num_layers=2),
        T.RandomHorizontalFlip(),
        Permute(dims=[1, 0, 2, 3]),
        T.RandomChoice(resize_variants),
        normalize,
    ])
|
|
|
|
| |
def make_dataloader(dataset, batch_size, num_workers, use_weighted_sampler=False):
    """Create a DataLoader over *dataset*.

    By default batches are drawn with ``shuffle=True``. When
    ``use_weighted_sampler`` is True, a ``WeightedRandomSampler`` built from
    ``dataset.get_sample_weights()`` is used instead (PyTorch forbids
    combining ``sampler`` with ``shuffle=True``, so shuffle is disabled in
    that case).

    Args:
        dataset: map-style dataset; must implement ``get_sample_weights()``
            only when ``use_weighted_sampler`` is True.
        batch_size: number of samples per batch.
        num_workers: number of DataLoader worker processes.
        use_weighted_sampler: draw samples proportionally to their weights
            (with replacement) instead of uniform shuffling.

    Returns:
        torch.utils.data.DataLoader
    """
    # Compute weights lazily: the original always called get_sample_weights()
    # and discarded the result, forcing the method onto every dataset.
    sampler = None
    if use_weighted_sampler:
        weights = dataset.get_sample_weights()
        sampler = WeightedRandomSampler(weights, num_samples=len(weights), replacement=True)

    return DataLoader(
        dataset,
        batch_size=batch_size,
        num_workers=num_workers,
        sampler=sampler,
        shuffle=sampler is None,
        drop_last=True,
        pin_memory=True,
    )
|
|
|
|
| |
def make_model(num_classes, video_shape, lr, weight_decay, max_epochs, weight_path=None):
    """Construct a SyntaxLightningModule with uniform arguments so the
    pretrain and full fine-tuning phases are created identically.

    Args:
        num_classes: number of output channels of the model head.
        video_shape: clip shape probed from a batch; accepted for interface
            compatibility but not forwarded to the module.
        lr: learning rate.
        weight_decay: optimizer weight decay.
        max_epochs: schedule length passed to the module.
        weight_path: optional checkpoint to initialize weights from.

    Returns:
        A freshly constructed SyntaxLightningModule.
    """
    return SyntaxLightningModule(
        num_classes=num_classes,
        lr=lr,
        weight_decay=weight_decay,
        max_epochs=max_epochs,
        weight_path=weight_path,
    )
|
|
|
|
| |
def make_callbacks(artery: str, fold: int, phase: str):
    """Build the callback set for one training phase.

    Returns:
        A list with:
          - LearningRateMonitor (per-epoch logging);
          - ModelCheckpoint tracking the lowest ``val_mae`` and always
            saving the last epoch; pretrain keeps the single best
            checkpoint, full fine-tuning keeps the best three.

    Args:
        artery: artery name (currently unused; kept for interface parity).
        fold: fold number (currently unused; kept for interface parity).
        phase: "pre" for pretraining or "full" for full fine-tuning.

    Raises:
        ValueError: if *phase* is neither "pre" nor "full".
    """
    # The two phases differed only in save_top_k — deduplicated into a map.
    top_k_by_phase = {"pre": 1, "full": 3}
    if phase not in top_k_by_phase:
        raise ValueError(f"Unknown phase '{phase}', expected 'pre' or 'full'")

    lr_monitor = LearningRateMonitor(logging_interval="epoch")

    # NOTE(review): monitor is val_mae but the filename embeds val_rmse —
    # this looks inconsistent; confirm val_rmse is also logged by the module,
    # or align the filename with the monitored metric.
    checkpoint = ModelCheckpoint(
        monitor="val_mae",
        save_top_k=top_k_by_phase[phase],
        mode="min",
        filename="model" + "-{epoch:02d}-{val_rmse:.3f}",
        save_last=True,
    )

    return [lr_monitor, checkpoint]
|
|
|
|
| |
def make_trainer(max_epochs, logger_name, callbacks):
    """Create a Lightning Trainer logging to ``backbone_logs/<logger_name>``.

    Hardware settings (single GPU, DDP strategy with unused-parameter
    detection, bf16 mixed precision) may need adapting to the target
    cluster.

    Args:
        max_epochs: number of training epochs.
        logger_name: run name under the TensorBoard save dir.
        callbacks: list of Lightning callbacks to attach.

    Returns:
        A configured ``pl.Trainer``.
    """
    tb_logger = TensorBoardLogger(save_dir="backbone_logs", name=logger_name)
    return pl.Trainer(
        max_epochs=max_epochs,
        accelerator="gpu",
        devices=1,
        strategy="ddp_find_unused_parameters_true",
        precision="bf16-mixed",
        callbacks=callbacks,
        log_every_n_steps=10,
        logger=tb_logger,
    )
|
|
|
|
@click.command()
@click.option(
    "-r",
    "--dataset-root",
    type=click.Path(exists=True),
    default=".",
    required=True,
    help="Путь к корню датасета (директория, внутри которой лежат JSON и DICOM).",
)
@click.option("--fold", type=int, default=0, required=True, help="Номер фолда (0–4).")
@click.option(
    "-a",
    "--artery",
    type=str,
    default="right",
    required=True,
    help="Название артерии: 'left' или 'right'.",
)
@click.option("-nc", "--num-classes", type=int, default=2, help="Число выходных каналов модели.")
@click.option("-b", "--batch-size", type=int, default=50, help="Размер batch.")
@click.option("-f", "--frames-per-clip", type=int, default=32, help="Количество кадров в клипе.")
@click.option(
    "-v",
    "--video-size",
    type=click.Tuple([int, int]),
    default=(256, 256),
    help="Размер кадра (H, W).",
)
@click.option("--max-epochs", type=int, default=10, help="Число эпох на этапе full fine-tuning.")
@click.option("--num-workers", type=int, default=8, help="Число воркеров для DataLoader.")
@click.option(
    "--fast-dev-run",
    is_flag=True,
    default=False,
    show_default=True,
    help="Режим быстрой проверки пайплайна (1–2 батча).",
)
@click.option("--seed", type=int, default=42, help="Сид для воспроизводимости.")
def main(
    dataset_root,
    fold,
    artery,
    num_classes,
    batch_size,
    frames_per_clip,
    video_size,
    max_epochs,
    num_workers,
    fast_dev_run,
    seed,
):
    """Two-stage training entry point for one artery and one CV fold:
    phase 1 pretrains (10 epochs, lr 3e-4, double batch size), phase 2
    fine-tunes the full backbone starting from the last pretrain checkpoint.

    NOTE(review): fast_dev_run is parsed but never forwarded to the Trainer —
    the flag currently has no effect; confirm whether it should be wired in.
    """
    pl.seed_everything(seed)


    # Map the artery name to the binary id the dataset expects; reject anything else.
    artery = artery.lower()
    artery_bin = {"left": 0, "right": 1}.get(artery)
    if artery_bin is None:
        raise ValueError(f"Unknown artery '{artery}', expected 'left' or 'right'.")


    # ImageNet channel statistics, used by Normalize in both transform pipelines.
    imagenet_mean = [0.485, 0.456, 0.406]
    imagenet_std = [0.229, 0.224, 0.225]


    # Fold split metadata; relative paths, resolved against the current working
    # directory (not dataset_root) — presumably intentional, verify on deployment.
    train_meta = os.path.join("folds", f"step2_fold{fold:02d}_train.json")
    val_meta = os.path.join("folds", f"step2_fold{fold:02d}_eval.json")


    # Training split with augmentations enabled.
    train_set = SyntaxDataset(
        root=dataset_root,
        meta=train_meta,
        train=True,
        length=frames_per_clip,
        label=f"syntax_{artery}",
        artery_bin=artery_bin,
        validation=False,
        transform=get_transforms(video_size, imagenet_mean, imagenet_std, train=True),
    )


    # Validation split with deterministic transforms.
    val_set = SyntaxDataset(
        root=dataset_root,
        meta=val_meta,
        train=False,
        length=frames_per_clip,
        label=f"syntax_{artery}",
        artery_bin=artery_bin,
        validation=True,
        transform=get_transforms(video_size, imagenet_mean, imagenet_std, train=False),
    )


    # Pretrain uses a doubled batch size; validation runs one clip at a time.
    train_loader_pre = make_dataloader(train_set, batch_size * 2, num_workers)
    train_loader_post = make_dataloader(train_set, batch_size, num_workers)
    val_loader = make_dataloader(val_set, 1, num_workers)


    # Pull a single batch just to record the per-sample clip shape.
    # NOTE(review): this spins up the loader workers once before training, and
    # make_model does not forward video_shape — confirm this probe is still needed.
    x, *_ = next(iter(train_loader_pre))
    video_shape = x.shape[1:]


    callbacks_pre = make_callbacks(artery=artery, fold=fold, phase="pre")
    callbacks_full = make_callbacks(artery=artery, fold=fold, phase="full")


    # Phase 1: pretraining with a fixed 10-epoch schedule and higher LR.
    num_pre_epochs = 10
    model_pre = make_model(
        num_classes=num_classes,
        video_shape=video_shape,
        lr=3e-4,
        weight_decay=0.01,
        max_epochs=num_pre_epochs,
    )
    trainer_pre = make_trainer(num_pre_epochs, f"{artery}BinSyntax_R3D_pre_fold{fold:02d}", callbacks_pre)
    trainer_pre.fit(model_pre, train_loader_pre, val_loader, ckpt_path=None)


    # Phase 2: full fine-tuning, warm-started from the *last* (not best)
    # pretrain checkpoint, with a lower LR.
    model_full = make_model(
        num_classes=num_classes,
        video_shape=video_shape,
        lr=1e-4,
        weight_decay=0.01,
        max_epochs=max_epochs,
        weight_path=trainer_pre.checkpoint_callback.last_model_path,
    )
    trainer_full = make_trainer(max_epochs, f"{artery}BinSyntax_R3D_full_fold{fold:02d}", callbacks_full)
    trainer_full.fit(model_full, train_loader_post, val_loader, ckpt_path=None)
|
|
|
|
| if __name__ == "__main__": |
| main() |
|
|