"""
S2F training script.

Usage:
    python -m training.train --data path/to/dataset --model single_cell --epochs 100
    python -m training.train --data path/to/dataset --model spheroid --epochs 50
"""
import os
import sys
import argparse

# Repository root: the parent of the directory that contains this file.
# It is put at the front of sys.path so the top-level packages imported inside
# main() (`data`, `models`, `training`) resolve when this file is run directly.
S2F_ROOT = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
if S2F_ROOT not in sys.path:
    sys.path.insert(0, S2F_ROOT)


def main():
    """Parse command-line options, build loaders and models, and run training.

    The ``--model`` choice drives the rest of the configuration:

    * ``single_cell``: metadata is requested from the dataset, a substrate
      name is passed to the loader (``--substrate``, falling back to
      ``'fibroblasts_PDMS'``), and the model is created with
      ``in_channels=3`` and ``model_type='s2f'``.
    * ``spheroid``: no metadata and no substrate are used, and the model is
      created with ``in_channels=1`` and ``model_type='s2f_spheroid'``.

    When ``--resume`` is given, only the generator weights are restored from
    the checkpoint; the discriminator is left as created.

    Exits via ``SystemExit`` (raised by argparse) on ``--help`` or on invalid
    or missing arguments.
    """
    parser = argparse.ArgumentParser(description='Train S2F model')
    parser.add_argument(
        '--data', required=True,
        help='Path to dataset (must have train/ and test/ subfolders)')
    parser.add_argument(
        '--model', choices=['single_cell', 'spheroid'], default='single_cell',
        help='Model type: single_cell (with substrate) or spheroid')
    parser.add_argument(
        '--substrate', default=None,
        help='Substrate name for single_cell when metadata not in dataset (e.g. fibroblasts_PDMS)')
    parser.add_argument('--epochs', type=int, default=100)
    parser.add_argument('--batch_size', type=int, default=4)
    parser.add_argument('--img_size', type=int, default=1024)
    parser.add_argument('--save_dir', default='ckp', help='Checkpoint save directory')
    parser.add_argument('--g_lr', type=float, default=2e-4)
    parser.add_argument('--d_lr', type=float, default=2e-4)
    parser.add_argument('--resume', type=str, default=None,
                        help='Path to checkpoint to resume from')
    parser.add_argument('--device', default='cuda')
    parser.add_argument('--no_augment', action='store_true', help='Disable augmentations')
    parser.add_argument('--use_force_consistency', action='store_true')
    parser.add_argument('--force_target', choices=['mean', 'sum'], default='mean')
    args = parser.parse_args()

    # Project imports are deferred until after argument parsing, so `--help`
    # and argument errors exit before these modules are loaded.
    from data.cell_dataset import prepare_data
    from models.s2f_model import create_s2f_model
    from training.s2f_trainer import train_s2f

    # Only the single-cell model uses per-sample settings/metadata.
    use_settings = args.model == 'single_cell'
    substrate = args.substrate or 'fibroblasts_PDMS'
    return_metadata = use_settings

    print(f"Loading data from {args.data} (model={args.model})")
    train_loader, val_loader = prepare_data(
        args.data,
        batch_size=args.batch_size,
        target_size=(args.img_size, args.img_size),
        use_augmentations=not args.no_augment,
        train_test_sep_folder=True,
        return_metadata=return_metadata,
        # The substrate is only forwarded for the single-cell model.
        substrate=substrate if use_settings else None,
    )

    in_channels = 3 if use_settings else 1
    model_type = 's2f' if use_settings else 's2f_spheroid'
    generator, discriminator = create_s2f_model(in_channels=in_channels, model_type=model_type)

    if args.resume:
        # torch is imported here so it is only loaded by this function when
        # a checkpoint actually has to be read.
        import torch

        # NOTE(security): weights_only=False unpickles arbitrary Python
        # objects; only resume from checkpoint files you trust.
        ckpt = torch.load(args.resume, map_location='cpu', weights_only=False)
        # Accept either a dict holding 'generator_state_dict' or a bare
        # state dict stored directly in the checkpoint.
        generator.load_state_dict(ckpt.get('generator_state_dict', ckpt), strict=True)
        print(f"Resumed from {args.resume}")

    train_s2f(
        generator, discriminator,
        train_loader, val_loader,
        device=args.device,
        num_epochs=args.epochs,
        g_lr=args.g_lr, d_lr=args.d_lr,
        save_dir=args.save_dir,
        loaded_metadata=return_metadata,
        use_settings=use_settings,
        use_force_consistency=args.use_force_consistency,
        force_consistency_target=args.force_target,
    )
    print(f"Training complete. Checkpoints saved to {args.save_dir}")


# Run training only when this file is executed as a script or via `python -m`,
# not when it is imported.
if __name__ == '__main__':
    main()