File size: 452 Bytes
7667a87
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
from datamodule import WeatherForecastDataModuleOld
import yaml 

DEFAULT_CONFIG_PATH = "MambaUnet_servir.yaml"


def describe_batch(batch):
    """Return a printable summary of the shapes contained in *batch*.

    Objects exposing a ``.shape`` attribute (e.g. tensors/arrays) are summarised
    by that attribute unchanged. Lists/tuples become a list of summaries and
    dicts a dict of summaries, recursively. Anything else is summarised by its
    type name, so the caller never hits ``AttributeError`` on a non-tensor batch.
    """
    if hasattr(batch, "shape"):
        return batch.shape
    if isinstance(batch, dict):
        return {key: describe_batch(value) for key, value in batch.items()}
    if isinstance(batch, (list, tuple)):
        return [describe_batch(element) for element in batch]
    return type(batch).__name__


def main(argv=None):
    """Build the data module from a YAML config and print the first training batch's shape.

    Args:
        argv: Optional argument list (defaults to ``sys.argv[1:]``). The single
            optional positional argument is the YAML config path; it defaults to
            ``MambaUnet_servir.yaml`` relative to the current working directory.
    """
    import argparse

    parser = argparse.ArgumentParser(
        description="Print the shape of the first batch from the training dataloader."
    )
    parser.add_argument(
        "config",
        nargs="?",
        default=DEFAULT_CONFIG_PATH,
        help=f"path to the YAML config (default: {DEFAULT_CONFIG_PATH})",
    )
    args = parser.parse_args(argv)

    with open(args.config, "r", encoding="utf-8") as f:
        # safe_load: never construct arbitrary Python objects from the config.
        config = yaml.safe_load(f)

    dm = WeatherForecastDataModuleOld(**config["data"])
    dm.prepare_data()
    dm.setup(stage=None)
    train_loader = dm.train_dataloader()

    # Only the first batch is inspected. The original author annotated the
    # expected layout as N,C,T,H,W -- not verified here.
    for batch in train_loader:
        print(describe_batch(batch))
        break
    else:
        print("train_dataloader() produced no batches")


if __name__ == "__main__":
    main()