import logging
import os

import numpy as np
import torch
import yaml

from src.data.containers import BatchTimeSeriesContainer
from src.models.model import TimeSeriesModel
from src.plotting.plot_timeseries import plot_from_container

logger = logging.getLogger(__name__)
|
|
| def load_model(config_path: str, model_path: str, device: torch.device) -> TimeSeriesModel: |
| """Load the TimeSeriesModel from config and checkpoint.""" |
| with open(config_path) as f: |
| config = yaml.safe_load(f) |
|
| model = TimeSeriesModel(**config["TimeSeriesModel"]).to(device) |
| checkpoint = torch.load(model_path, map_location=device) |
| model.load_state_dict(checkpoint["model_state_dict"]) |
| model.eval() |
| logger.info(f"Successfully loaded TimeSeriesModel from {model_path} on {device}") |
| return model |
|
|
| def plot_with_library( |
| container: BatchTimeSeriesContainer, |
| predictions_np: np.ndarray, |
| model_quantiles: list[float] | None, |
| output_dir: str = "outputs", |
| show_plots: bool = True, |
| save_plots: bool = True, |
| ): |
    """Plot each sample in the batch via plot_from_container, optionally saving and/or showing the figures."""
    os.makedirs(output_dir, exist_ok=True)
| batch_size = container.batch_size |
| for i in range(batch_size): |
| output_file = os.path.join(output_dir, f"sine_wave_prediction_sample_{i + 1}.png") if save_plots else None |
| plot_from_container( |
| batch=container, |
| sample_idx=i, |
| predicted_values=predictions_np, |
| model_quantiles=model_quantiles, |
| title=f"Sine Wave Time Series Prediction - Sample {i + 1}", |
| output_file=output_file, |
| show=show_plots, |
| ) |
|
|
| def run_inference_and_plot( |
| model: TimeSeriesModel, |
| container: BatchTimeSeriesContainer, |
| output_dir: str = "outputs", |
| use_bfloat16: bool = True, |
| ) -> None: |
| """Run model inference with optional bfloat16 and plot using shared utilities.""" |
    # bfloat16 autocast is only enabled when the batch already lives on a CUDA device.
    device_type = "cuda" if container.history_values.device.type == "cuda" else "cpu"
    autocast_enabled = use_bfloat16 and device_type == "cuda"
| with ( |
| torch.no_grad(), |
| torch.autocast(device_type=device_type, dtype=torch.bfloat16, enabled=autocast_enabled), |
| ): |
| model_output = model(container) |
|
    # Cast back to float32 (autocast may have produced bfloat16) and, if the model
    # scaled its inputs, invert that scaling so predictions are in the original data range.
    preds_full = model_output["result"].to(torch.float32)
    if hasattr(model, "scaler") and "scale_statistics" in model_output:
        preds_full = model.scaler.inverse_scale(preds_full, model_output["scale_statistics"])

    preds_np = preds_full.detach().cpu().numpy()
    # Quantile levels are only meaningful for models trained with a quantile loss.
    model_quantiles = model.quantiles if getattr(model, "loss_type", None) == "quantile" else None
| plot_with_library( |
| container=container, |
| predictions_np=preds_np, |
| model_quantiles=model_quantiles, |
| output_dir=output_dir, |
| show_plots=True, |
| save_plots=True, |
| ) |
|
|
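# The guarded block below is a usage sketch, not part of the library surface: it wires
# load_model and run_inference_and_plot together on synthetic sine waves. The
# config/checkpoint paths are placeholders, and the keyword arguments passed to
# BatchTimeSeriesContainer (history_values, future_values) plus the (batch, time,
# channels) layout are assumptions about its constructor; adjust both to the actual repo.
if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = load_model("configs/model.yaml", "checkpoints/model.pt", device)

    # Build a small batch of single-channel sine waves with random phases: (batch, time, 1).
    batch_size, history_len, future_len = 4, 256, 64
    t = np.arange(history_len + future_len, dtype=np.float32)
    phases = np.random.uniform(0.0, 2.0 * np.pi, size=(batch_size, 1))
    waves = np.sin(0.1 * t[None, :] + phases)[..., None].astype(np.float32)

    # Field names here are an assumption about the container's constructor.
    container = BatchTimeSeriesContainer(
        history_values=torch.tensor(waves[:, :history_len], device=device),
        future_values=torch.tensor(waves[:, history_len:], device=device),
    )

    run_inference_and_plot(model, container, output_dir="outputs", use_bfloat16=True)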