# Source: figsr/configs/trainner-redux.yml (uploaded by umzi)
# Commit 4f763cc (verified): "Upload folder using huggingface_hub"
# yaml-language-server: $schema=https://raw.githubusercontent.com/the-database/traiNNer-redux/refs/heads/master/schemas/redux-config.schema.json
#########################################################################################
# General Settings
# https://trainner-redux.readthedocs.io/en/latest/config_reference.html#top-level-options
#########################################################################################
name: 4x_figsr

# Upscaling factor. Supported values: 1, 2, 3, 4, 8.
scale: 4

# Speed up training and reduce VRAM usage. NVIDIA only.
use_amp: true

# Use bf16 instead of fp16 for AMP, RTX 3000 series or newer only.
# Only recommended if fp16 doesn't work.
amp_bf16: false

# Enable channels last memory format while using AMP. Reduces VRAM and speeds
# up training for most architectures, but some architectures are slower with
# channels last.
use_channels_last: true

# Trade precision for performance.
fast_matmul: false

# Enable torch.compile for generator. Takes time on startup to compile the
# model, but can speed up training after the model is compiled.
use_compile: false

# Mode to use with torch.compile.
# See https://docs.pytorch.org/docs/stable/generated/torch.compile.html for more info.
compile_mode: default

num_gpu: auto

# Random seed for training, useful for removing randomness when testing the
# effect of different settings. Uncomment to enable.
# manual_seed: 1024
########################################################################################################################
# Dataset and Dataloader Settings
# https://trainner-redux.readthedocs.io/en/latest/config_reference.html#dataset-options-datasets-train-and-datasets-val
########################################################################################################################
# Training and validation datasets. `train` and `val` here are children of
# `datasets`; they are distinct from the top-level `train` and `val` sections.
datasets:
  # Settings for the training dataset.
  train:
    name: Train Dataset
    type: pairedimagedataset
    # Path to the HR (high res) images in your training dataset. Specify one or
    # multiple folders, separated by commas.
    dataroot_gt: [datasets/BHI]
    # Path to the LR (low res) images in your training dataset, in the same
    # list format as dataroot_gt.
    dataroot_lq: [datasets/BHI_lq]
    # meta_info: data/meta_info/dataset1.txt

    # During training, a square of this size is cropped from LR images. Larger
    # is usually better but uses more VRAM. Previously gt_size, use
    # lq_size = gt_size / scale to convert. Use multiple of 8 for best
    # performance with AMP.
    lq_size: 64
    use_hflip: true  # Randomly flip the images horizontally.
    use_rot: true  # Randomly rotate the images.

    num_worker_per_gpu: 8
    # recommended: 64. Increasing stabilizes training but with diminishing
    # returns. Use multiple of 8 for best performance with AMP.
    batch_size_per_gpu: 64
    # Using values larger than 1 simulates higher batch size by trading
    # performance for reduced VRAM usage. If accum_iter = 4 and
    # batch_size_per_gpu = 6 then effective batch size = 4 * 6 = 24 but
    # performance may be as much as 4 times as slow.
    accum_iter: 1

  # Settings for your validation dataset (optional). These settings will
  # be ignored if val_enabled is false in the Validation section below.
  val:
    name: Val Dataset
    type: pairedimagedataset
    dataroot_gt: [datasets/val/dataset1/hr, datasets/val/dataset1/hr2]
    dataroot_lq: [datasets/val/dataset1/lr, datasets/val/dataset1/lr2]
#####################################################################
# Network Settings
# https://trainner-redux.readthedocs.io/en/latest/arch_reference.html
#####################################################################
# Generator model settings
# Generator network. Only the architecture type is set here; no
# architecture-specific options are overridden in this file.
network_g:
  type: FIGSR
#########################################################################################
# Pretrain and Resume Paths
# https://trainner-redux.readthedocs.io/en/latest/config_reference.html#path-options-path
#########################################################################################
# Pretrain / resume paths. `null` means the option is not set.
path:
  # pretrain_network_g: experiments/pretrained_models/pretrain.pth
  param_key_g: null
  # Disable strict loading to partially load a pretrain model with a different scale.
  strict_load_g: true
  resume_state: null
###########################################################################################
# Training Settings
# https://trainner-redux.readthedocs.io/en/latest/config_reference.html#train-options-train
###########################################################################################
# Training hyperparameters: EMA, optimizer, LR schedule, iteration budget and losses.
train:
  ema_decay: 0.999
  ema_power: 0.75  # Gradually warm up ema decay when training from scratch
  grad_clip: false  # Gradient clipping allows more stable training when using higher learning rates.

  # Optimizer for generator model
  optim_g:
    type: AdamW
    # The explicit !!float tag keeps this a number: some YAML 1.1 parsers
    # (e.g. PyYAML) read a bare `5e-4` (no decimal point) as a string.
    lr: !!float 5e-4
    weight_decay: 0
    betas: [0.9, 0.99]

  # MultiStepLR multiplies the learning rate by `gamma` at each milestone
  # iteration, i.e. it is halved four times over the run.
  scheduler:
    type: MultiStepLR
    milestones: [200000, 400000, 600000, 800000]
    gamma: 0.5

  total_iter: 1000000  # Total number of iterations.
  # Gradually ramp up learning rates until this iteration, to stabilize early
  # training. Use -1 to disable.
  warmup_iter: -1

  # Losses - for any loss set the loss_weight to 0 to disable it.
  # https://trainner-redux.readthedocs.io/en/latest/loss_reference.html
  losses:
    # Charbonnier loss
    - type: charbonnierloss
      loss_weight: 1.0
##############################################################################################
# Validation
# https://trainner-redux.readthedocs.io/en/latest/config_reference.html#validation-options-val
##############################################################################################
# Validation schedule, tiling and metrics.
val:
  # Whether to enable validations. If disabled, all validation settings below are ignored.
  val_enabled: true
  val_freq: 5000  # How often to run validations, in iterations.
  # Whether to save the validation images during validation, in the
  # experiments/<name>/visualization folder.
  save_img: true
  tile_size: 0  # Tile size of input, reduce VRAM usage but slower inference. 0 to disable.
  tile_overlap: 8  # Number of pixels to overlap tiles by, larger is slower but reduces tile seams.

  metrics_enabled: true  # Whether to run metrics calculations during validation.
  metrics:
    psnr:
      type: calculate_psnr
      # Pixels cropped from each image edge before the metric is computed
      # (here equal to `scale`).
      crop_border: 4
      test_y_channel: true  # Whether to convert to Y(CbCr) for validation.
    ssim:
      type: calculate_ssim
      # Pixels cropped from each image edge before the metric is computed
      # (here equal to `scale`).
      crop_border: 4
      test_y_channel: true  # Whether to convert to Y(CbCr) for validation.
    # Optional metrics; uncomment a key together with its children to enable it.
    # topiq:
    #   type: calculate_topiq
    # lpips:
    #   type: calculate_lpips
    #   better: lower
    # dists:
    #   type: calculate_dists
    #   better: lower
##############################################################################################
# Logging
# https://trainner-redux.readthedocs.io/en/latest/config_reference.html#logging-options-logger
##############################################################################################
# Logging and checkpoint settings.
# NOTE(review): the *_freq values are presumably counted in iterations, like
# val_freq in the validation section; confirm against the logger docs.
logger:
  print_freq: 100
  save_checkpoint_freq: 5000
  save_checkpoint_format: safetensors
  use_tb_logger: true