# prediff_code / MambaUnet_servir.yaml
# (Hugging Face upload metadata, kept as comments: uploaded by
#  weatherforecast1024 via huggingface_hub, revision 7667a87 verified)
---
seed_everything: 42
# ---------------------------- TRAINER -------------------------------------------
trainer:
  default_root_dir: "checkpoint/MambaUnet"
  precision: "16-mixed"
  min_epochs: 1
  max_epochs: 100
  accelerator: cuda
  # limit_train_batches: 10
  devices: [2]
  # strategy: ddp
  num_nodes: 1
  enable_progress_bar: true
  sync_batchnorm: true
  enable_checkpointing: true
  # debugging
  fast_dev_run: false
  logger:
    - class_path: pytorch_lightning.loggers.WandbLogger
      init_args:
        project: "MambaUnet_Seq_UCTransnet"
        name: "MambaUnet_Seq_UCTransnet_Wandb_test"
        save_dir: "checkpoint/MambaUnet/wandb_logs"
        log_model: false
    - class_path: pytorch_lightning.loggers.CSVLogger
      init_args:
        save_dir: "checkpoint/MambaUnet/logs"
        name: null
        version: null
  callbacks:
    - class_path: pytorch_lightning.callbacks.LearningRateMonitor
      init_args:
        logging_interval: "step"
    - class_path: pytorch_lightning.callbacks.ModelCheckpoint
      init_args:
        dirpath: "checkpoint/MambaUnet/checkpoints"
        monitor: "val/mse"  # name of the logged metric which determines when model is improving
        mode: "min"  # "max" means higher metric value is better, can be also "min"
        save_top_k: 1  # save k best models (determined by above metric)
        save_last: true  # additionally always save model from last epoch
        verbose: false
        filename: "epoch_{epoch:03d}"
        auto_insert_metric_name: false
    - class_path: pytorch_lightning.callbacks.EarlyStopping
      init_args:
        monitor: "val/mse"  # name of the logged metric which determines when model is improving
        mode: "min"  # "max" means higher metric value is better, can be also "min"
        patience: 10  # how many validation epochs of not improving until training stops
        min_delta: 0.0  # minimum change in the monitored metric needed to qualify as an improvement
    - class_path: pytorch_lightning.callbacks.RichModelSummary
      init_args:
        max_depth: -1
    - class_path: pytorch_lightning.callbacks.RichProgressBar
# ---------------------------- MODEL -------------------------------------------
model:
  pretrained_path: ""
  beta_1: 0.9
  beta_2: 0.99
  # NOTE: exponent floats need a decimal point ("5.0e-4"); the dot-less form
  # "5e-4" is loaded as a *string* by YAML 1.1 parsers such as PyYAML.
  lr: 5.0e-4
  weight_decay: 1.0e-5
  warmup_epochs: 10
  max_epochs: 50  # NOTE(review): differs from trainer.max_epochs (100) — confirm the scheduler horizon is intentional
  warmup_start_lr: 1.0e-8
  eta_min: 1.0e-8
  net:
    radar_timestamp: 4  # Initial Input for VSSM
    radar_data_types: 2  # Initial Input for VSSM
    sat_data_types: 8  # Fusion concat input for bottleneck layer
    # endregion
    # region Preprocess Embedding
    patch_size: 4
    preprocess_depth: 2
    # endregion
    # region Encoder & Decoder Path
    depths: [3, 4]  # Encoder and Decoder Depth
    dims: [32, 64]  # Channel after each VS Block
    size: 400  # Initial Size of H,W of Input(Radar - 400)
    # endregion
    # region Bottleneck
    bottleneck_depth: 8  # The length of a VS Block
    bottleneck_dim: 128  # The dimension after the encoding phase, will be doubled to fuse radar and sat
    bottleneck_size: 25  # The H,W size of the satellite image
    # endregion
    # region Output
    sat_output: 1
    radar_output: 1
    # endregion
    # region Config
    ape: true
    final_skip: false
    patch_norm: true
    drop_rate: 0.0  # must be in [0:1]
    drop_path_rate: 0.1  # must be in [0:1]
    use_checkpoint: false
    # was `None`, which YAML loads as the string "None"; `null` is the real null value
    pretrained_path: null
    # endregion
    # skip connection
    upgrade_skip_connection: "UCTransNet"
    img_size: 400
    vis: false
    channel_num: [8, 128, 256]
    patchSize: [40, 10, 5]
    KV_size: 392
    num_layers: 3
    num_heads: 3
    attention_dropout_rate: 0.1
    embeddings_dropout_rate: 0.1
    dropout_rate: 0.1
    expand_ratio: 3
    cca: true
# ---------------------------- DATA -------------------------------------------
data:
  dir_data: "/data/weather2025/NhaBe"
  time_points_radar: 1
  time_points_sat: 1
  sat_inp_vars:
    - "land_sea_mask"
    - "orography"
    # NOTE(review): "lattitude" looks like a typo for "latitude", but the value is
    # kept verbatim — it must match the variable name used by the dataset/loader.
    - "lattitude"
    - "2m_temperature"
    - "10m_u_component_of_wind"
    - "10m_v_component_of_wind"
    - "total_precipitation"
    - "2m_dewpoint_temperature"
  sat_out_vars: "total_precipitation"
  rad_inp_vars:
    - "precipitation"
  rad_out_vars: "precipitation"
  hours_predicted: 3
  batch_size: 16
  num_workers: 4
  pin_memory: false
  short_timestep: false
  rebuild_val: true
  augmentation: true
  servir_format: true
  # global_prior: true