| import gc |
| import gzip |
| import os |
| import pickle |
|
|
| from fire import Fire |
| import numpy as np |
| from omegaconf import OmegaConf |
|
|
| from ldcast.features import batch, patches, split, transform |
| from ldcast.models.autoenc import encoder, training |
|
|
| |
# Root directory of the preprocessed data. setup_data joins the per-variable
# patch subdirectories and the chunks file onto this path. The value is
# hard-coded for this deployment and carries no trailing slash.
file_dir = "/data/data_WF/ldcast_precipitation/preprocess_data"
|
|
def setup_data(
    var="RZC",
    batch_size=32,
    sampler_file=None,
    num_timesteps=8,
    chunks_file="../data/split_chunks.pkl.gz"
):
    """Build the data module for autoencoder training on a single variable.

    The predictors are ``[var]`` and the target is ``var``, so the model's
    input and output are the same variable.

    Args:
        var: Variable name. Patches are loaded from the ``f"{var}/"``
            subdirectory of ``file_dir``.
        batch_size: Batch size forwarded to ``split.DataModule``.
        sampler_file: Sampler file specification forwarded to
            ``split.DataModule``. When None, a dict of per-split paths under
            ``../cache/`` is used. Those paths are not joined with
            ``file_dir``, so presumably they resolve against the current
            working directory.
        num_timesteps: Length of the timestep index
            (``np.arange(num_timesteps)``) used for the variable.
        chunks_file: Gzipped pickle with the precomputed train/valid/test
            chunks, joined onto ``file_dir``.

    Returns:
        The configured ``split.DataModule``.
    """
    patch_dir = os.path.join(file_dir, f"{var}/")
    raw = {var: patches.load_all_patches(patch_dir, var)}

    # NOTE(review): pickle.load can execute arbitrary code embedded in the
    # file; only point chunks_file at trusted, locally generated data.
    with gzip.open(os.path.join(file_dir, chunks_file), "rb") as chunk_fh:
        chunks = pickle.load(chunk_fh)
    raw, _ = split.train_valid_test_split(raw, var, chunks=chunks)

    # The rain-rate transform is parameterised by the training split's scale.
    rain_transform = transform.default_rainrate_transform(
        raw["train"][var]["scale"]
    )
    variables = {
        var: {
            "sources": [var],
            "timesteps": np.arange(num_timesteps),
            "transform": rain_transform,
        }
    }

    if sampler_file is None:
        sampler_file = {
            "train": "../cache/sampler_autoenc_train.pkl",
            "valid": "../cache/sampler_autoenc_valid.pkl",
            "test": "../cache/sampler_autoenc_test.pkl",
        }

    # 10 logarithmically spaced values from 0.2 to 50, used as sampling bins.
    sampling_bins = np.exp(np.linspace(np.log(0.2), np.log(50), 10))
    datamodule = split.DataModule(
        variables, raw, [var], var, var,
        sampling_bins=sampling_bins, batch_size=batch_size,
        sampler_file=sampler_file,
        valid_seed=1234, test_seed=2345,
    )

    gc.collect()
    return datamodule
|
|
|
|
def setup_model(model_dir=None):
    """Create the convolutional autoencoder together with its trainer.

    Args:
        model_dir: Forwarded unchanged to
            ``training.setup_autoenc_training``.

    Returns:
        Tuple ``(autoencoder, trainer)`` as produced by
        ``training.setup_autoenc_training``.
    """
    conv_encoder = encoder.SimpleConvEncoder()
    conv_decoder = encoder.SimpleConvDecoder()
    autoencoder, trainer = training.setup_autoenc_training(
        encoder=conv_encoder,
        decoder=conv_decoder,
        model_dir=model_dir,
    )
    # Force a garbage-collection pass before handing the model back.
    gc.collect()
    return autoencoder, trainer
|
|
|
|
def train(
    var="RZC",
    batch_size=8,
    sampler_file=None,
    num_timesteps=8,
    chunks_file="split_chunks.pkl.gz",
    model_dir=None,
    ckpt_path=None
):
    """Fit the autoencoder, then run the trainer's test loop.

    Args:
        var: Variable to train on (forwarded to ``setup_data``).
        batch_size: Batch size (forwarded to ``setup_data``).
        sampler_file: Sampler file specification (forwarded to
            ``setup_data``).
        num_timesteps: Timesteps per sample (forwarded to ``setup_data``).
        chunks_file: Train/valid/test chunks file (forwarded to
            ``setup_data``).
        model_dir: Forwarded to ``setup_model``.
        ckpt_path: Forwarded as ``ckpt_path`` to ``trainer.fit``.
    """
    print("Loading data...")
    data_kwargs = dict(
        var=var,
        batch_size=batch_size,
        sampler_file=sampler_file,
        num_timesteps=num_timesteps,
        chunks_file=chunks_file,
    )
    datamodule = setup_data(**data_kwargs)

    print("Setting up model...")
    model, trainer = setup_model(model_dir=model_dir)

    print("Starting training...")
    trainer.fit(model, datamodule=datamodule, ckpt_path=ckpt_path)
    # NOTE(review): this tests the in-memory model as left by fit(), not
    # necessarily the best checkpoint; confirm that is the intent.
    trainer.test(model, datamodule=datamodule)
|
|
def main(config=None, **kwargs):
    """Command-line entry point: merge an optional YAML config with overrides, then train.

    Args:
        config: Optional path to a YAML file loaded with ``OmegaConf.load``.
            Its top-level keys become keyword arguments to ``train``.
        **kwargs: Keyword overrides (e.g. supplied by ``fire.Fire`` from the
            command line). They take precedence over values from ``config``.

    Raises:
        TypeError: If the config file's top level is not a mapping.
    """
    params = {}
    if config is not None:
        # Convert to plain Python containers (with interpolations resolved)
        # so nested values such as a sampler_file mapping reach train() as
        # ordinary dicts/lists rather than OmegaConf node objects.
        loaded = OmegaConf.to_container(OmegaConf.load(config), resolve=True)
        # A YAML list would load as a ListConfig; train() only takes keyword
        # arguments, so fail with a clear message instead of AttributeError.
        if not isinstance(loaded, dict):
            raise TypeError(
                f"config file {config!r} must contain a mapping of train() "
                f"arguments, got {type(loaded).__name__}"
            )
        params.update(loaded)
    params.update(kwargs)
    train(**params)
|
|
|
|
if __name__ == "__main__":
    # Expose main() as a command-line interface via python-fire.
    Fire(component=main)
|
|