# File size: 452 Bytes
# 7667a87 (identifier carried over from the scraped file-listing header)
import argparse

import yaml

from datamodule import WeatherForecastDataModuleOld
DEFAULT_CONFIG_PATH = "MambaUnet_servir.yaml"


def main(config_path=DEFAULT_CONFIG_PATH):
    """Smoke-test the training data pipeline and print the first batch's shape.

    Loads the YAML config at *config_path*, passes its ``data`` section as
    keyword arguments to ``WeatherForecastDataModuleOld``, calls
    ``prepare_data()`` and ``setup(stage=None)``, then pulls a single batch
    from ``train_dataloader()`` and prints its ``.shape``.

    Args:
        config_path: Path to the YAML config file. Relative paths are
            resolved against the current working directory (plain ``open``).

    Raises:
        KeyError: if the config has no top-level ``data`` key.
    """
    with open(config_path, "r", encoding="utf-8") as f:
        config = yaml.safe_load(f)

    dm = WeatherForecastDataModuleOld(**config["data"])
    dm.prepare_data()
    dm.setup(stage=None)

    train_loader = dm.train_dataloader()
    # Only the first batch is needed. The None default avoids StopIteration
    # (and makes the empty case visible) when the loader yields nothing.
    batch = next(iter(train_loader), None)
    if batch is None:
        print("train_dataloader() yielded no batches")
        return

    # The original author noted the layout as N, C, T, H, W.
    # NOTE(review): assumes each batch is a single array-like object exposing
    # `.shape`; if the loader yields a tuple (e.g. inputs, targets) this line
    # will raise AttributeError. Confirm against the datamodule.
    print(batch.shape)


if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="Print the shape of the first training batch."
    )
    parser.add_argument(
        "config",
        nargs="?",
        default=DEFAULT_CONFIG_PATH,
        help="YAML config file (default: %(default)s)",
    )
    main(parser.parse_args().config)