---
seed_everything: 42

# ---------------------------- TRAINER -------------------------------------------
trainer:
  default_root_dir: "checkpoint/MambaUnet"
  precision: "16-mixed"
  min_epochs: 1
  max_epochs: 100
  accelerator: cuda
  # limit_train_batches: 10
  devices: [2]  # a list selects specific device indices: here only GPU index 2
  # strategy: ddp
  num_nodes: 1
  enable_progress_bar: true
  sync_batchnorm: true
  enable_checkpointing: true
  # debugging
  fast_dev_run: false

  logger:
    - class_path: pytorch_lightning.loggers.WandbLogger
      init_args:
        project: "MambaUnet_Seq_UCTransnet"
        name: "MambaUnet_Seq_UCTransnet_Wandb_test"
        save_dir: "checkpoint/MambaUnet/wandb_logs"
        log_model: false
    - class_path: pytorch_lightning.loggers.CSVLogger
      init_args:
        save_dir: "checkpoint/MambaUnet/logs"
        name: null
        version: null

  callbacks:
    - class_path: pytorch_lightning.callbacks.LearningRateMonitor
      init_args:
        logging_interval: "step"
    - class_path: pytorch_lightning.callbacks.ModelCheckpoint
      init_args:
        dirpath: "checkpoint/MambaUnet/checkpoints"
        # Logged metric that determines whether the model is improving.
        monitor: "val/mse"
        mode: "min"  # lower val/mse is better ("max" would mean higher is better)
        save_top_k: 1  # keep the k best checkpoints, ranked by `monitor`
        save_last: true  # additionally always save the checkpoint from the last epoch
        verbose: false
        filename: "epoch_{epoch:03d}"
        auto_insert_metric_name: false
    - class_path: pytorch_lightning.callbacks.EarlyStopping
      init_args:
        # Logged metric that determines whether the model is improving.
        monitor: "val/mse"
        mode: "min"  # lower val/mse is better ("max" would mean higher is better)
        patience: 10  # validation checks without improvement before training stops
        min_delta: 0.0  # minimum change in the monitored metric that counts as an improvement
    - class_path: pytorch_lightning.callbacks.RichModelSummary
      init_args:
        max_depth: -1
    - class_path: pytorch_lightning.callbacks.RichProgressBar

# ---------------------------- MODEL -------------------------------------------
model:
  pretrained_path: ""
  beta_1: 0.9
  beta_2: 0.99
  # Exponent floats are written with a decimal point: YAML 1.1 loaders (e.g.
  # PyYAML) resolve a bare "5e-4" as a string, not a float.
  lr: 5.0e-4
  weight_decay: 1.0e-5
  warmup_epochs: 10
  # NOTE(review): trainer.max_epochs is 100 but this is 50; presumably this one is
  # the LR-schedule horizon — confirm the two are meant to differ.
  max_epochs: 50
  warmup_start_lr: 1.0e-8
  eta_min: 1.0e-8
  net:
    # region Input
    radar_timestamp: 4  # initial input for VSSM
    radar_data_types: 2  # initial input for VSSM
    # Fusion concat input for the bottleneck layer; presumably equals
    # len(data.sat_inp_vars) (currently 8) — confirm.
    sat_data_types: 8
    # endregion
    # region Preprocess Embedding
    patch_size: 4
    preprocess_depth: 2
    # endregion
    # region Encoder & Decoder Path
    depths: [3, 4]  # encoder and decoder depth
    dims: [32, 64]  # channels after each VS block
    size: 400  # initial H, W of the input (radar: 400)
    # endregion
    # region Bottleneck
    bottleneck_depth: 8  # length of a VS block in the bottleneck
    # Dimension after the encoding phase; doubled to fuse radar and sat.
    bottleneck_dim: 128
    bottleneck_size: 25  # H, W size of the satellite image
    # endregion
    # region Output
    sat_output: 1
    radar_output: 1
    # endregion
    # region Config
    ape: true
    final_skip: false
    patch_norm: true
    drop_rate: 0.0  # must be in [0, 1]
    drop_path_rate: 0.1  # must be in [0, 1]
    use_checkpoint: false
    # YAML has no `None` keyword: the bare word would load as the string "None".
    pretrained_path: null
    # endregion
    # region Skip connection
    upgrade_skip_connection: "UCTransNet"
    img_size: 400
    vis: false
    channel_num: [8, 128, 256]
    patchSize: [40, 10, 5]
    # Equals sum(channel_num) = 8 + 128 + 256; keep in sync if channel_num changes
    # (assumed UCTransNet convention — confirm).
    KV_size: 392
    num_layers: 3
    num_heads: 3
    attention_dropout_rate: 0.1
    embeddings_dropout_rate: 0.1
    dropout_rate: 0.1
    expand_ratio: 3
    cca: true
    # endregion

# ---------------------------- DATA -------------------------------------------
data:
  dir_data: "/data/weather2025/NhaBe"
  time_points_radar: 1
  time_points_sat: 1
  sat_inp_vars:
    - "land_sea_mask"
    - "orography"
    # NOTE(review): "lattitude" is misspelled, but it must match the variable name
    # used by the dataset — verify against the data files before renaming.
    - "lattitude"
    - "2m_temperature"
    - "10m_u_component_of_wind"
    - "10m_v_component_of_wind"
    - "total_precipitation"
    - "2m_dewpoint_temperature"
  sat_out_vars: "total_precipitation"
  rad_inp_vars:
    - "precipitation"
  rad_out_vars: "precipitation"
  hours_predicted: 3
  batch_size: 16
  num_workers: 4
  pin_memory: false
  short_timestep: false
  rebuild_val: true
  augmentation: true
  servir_format: true
  # global_prior: true