| import logging |
|
|
| import matplotlib.pyplot as plt |
| import numpy as np |
| import pandas as pd |
| import torch |
| import torchmetrics |
| from matplotlib.figure import Figure |
|
|
| from src.data.containers import BatchTimeSeriesContainer |
| from src.data.frequency import Frequency |
|
|
# Module-level logger; used below to report SMAPE and quantile-extraction
# problems as warnings without interrupting plotting.
logger = logging.getLogger(__name__)
|
|
|
|
def calculate_smape(y_true: np.ndarray, y_pred: np.ndarray) -> float:
    """Return the Symmetric Mean Absolute Percentage Error (SMAPE) of a forecast.

    Args:
        y_true: Ground-truth values (must be a NumPy array).
        y_pred: Predicted values, expected to have the same shape as ``y_true``.

    Returns:
        The score computed by torchmetrics' ``SymmetricMeanAbsolutePercentageError``
        on float32 tensors built from both arrays, as a Python float.
        NOTE(review): torchmetrics appears to report SMAPE as a ratio rather than
        a percentage (no x100) — confirm against the installed version.
    """
    # Convert predictions first, then targets; torchmetrics expects (preds, target).
    preds = torch.from_numpy(y_pred).float()
    target = torch.from_numpy(y_true).float()
    smape_metric = torchmetrics.SymmetricMeanAbsolutePercentageError()
    score = smape_metric(preds, target)
    return score.item()
|
|
|
|
| def _create_date_ranges( |
| start: np.datetime64 | pd.Timestamp | None, |
| frequency: Frequency | str | None, |
| history_length: int, |
| prediction_length: int, |
| ) -> tuple[pd.DatetimeIndex, pd.DatetimeIndex]: |
| """Create date ranges for history and future periods.""" |
| if start is not None and frequency is not None: |
| start_timestamp = pd.Timestamp(start) |
| pandas_freq = frequency.to_pandas_freq(for_date_range=True) |
|
|
| history_dates = pd.date_range(start=start_timestamp, periods=history_length, freq=pandas_freq) |
|
|
| if prediction_length > 0: |
| next_timestamp = history_dates[-1] + pd.tseries.frequencies.to_offset(pandas_freq) |
| future_dates = pd.date_range(start=next_timestamp, periods=prediction_length, freq=pandas_freq) |
| else: |
| future_dates = pd.DatetimeIndex([]) |
| else: |
| |
| history_dates = pd.date_range(end=pd.Timestamp.now(), periods=history_length, freq="D") |
|
|
| if prediction_length > 0: |
| future_dates = pd.date_range( |
| start=history_dates[-1] + pd.Timedelta(days=1), |
| periods=prediction_length, |
| freq="D", |
| ) |
| else: |
| future_dates = pd.DatetimeIndex([]) |
|
|
| return history_dates, future_dates |
|
|
|
|
def _plot_single_channel(
    ax: plt.Axes,
    channel_idx: int,
    history_dates: pd.DatetimeIndex,
    future_dates: pd.DatetimeIndex,
    history_values: np.ndarray,
    future_values: np.ndarray | None = None,
    predicted_values: np.ndarray | None = None,
    lower_bound: np.ndarray | None = None,
    upper_bound: np.ndarray | None = None,
) -> None:
    """Draw one channel (column ``channel_idx``) of each supplied array onto ``ax``.

    History is plotted against ``history_dates``; ground truth, median
    prediction, and the lower/upper band are plotted against ``future_dates``.
    Optional series that are None are skipped, and the band is drawn only when
    both bounds are given.
    """
    col = channel_idx

    # History is always drawn first so it leads the legend.
    ax.plot(history_dates, history_values[:, col], color="black", label="History")

    # Optional future-period lines, drawn in a fixed order: ground truth, then median forecast.
    future_line_styles = (
        (future_values, {"color": "blue", "label": "Ground Truth"}),
        (predicted_values, {"color": "orange", "linestyle": "--", "label": "Prediction (Median)"}),
    )
    for series, line_style in future_line_styles:
        if series is None:
            continue
        ax.plot(future_dates, series[:, col], **line_style)

    has_band = lower_bound is not None and upper_bound is not None
    if has_band:
        ax.fill_between(
            future_dates,
            lower_bound[:, col],
            upper_bound[:, col],
            color="orange",
            alpha=0.2,
            label="Uncertainty Band",
        )

    # Channels are labelled 1-based for display.
    ax.set_title(f"Channel {col + 1}")
    ax.grid(True, which="both", linestyle="--", linewidth=0.5)
|
|
|
|
def _setup_figure(num_channels: int) -> tuple[Figure, list[plt.Axes]]:
    """Create a figure with one full-width subplot row per channel.

    Args:
        num_channels: Number of vertically stacked subplots; they share the x axis.

    Returns:
        ``(fig, axes)`` where ``axes`` is a plain list holding one Axes per
        channel, ordered top to bottom.

    Raises:
        ValueError: If ``num_channels`` is less than 1.
    """
    if num_channels < 1:
        raise ValueError(f"num_channels must be at least 1, got {num_channels}")
    # squeeze=False makes plt.subplots always return a 2-D (rows, cols) array of
    # Axes, so one and many channels are handled identically and callers always
    # get a real list, as the annotation promises (previously an ndarray was
    # returned for more than one channel).
    fig, axes_grid = plt.subplots(
        num_channels,
        1,
        figsize=(15, 3 * num_channels),
        sharex=True,
        squeeze=False,
    )
    return fig, list(axes_grid[:, 0])
|
|
|
|
def _finalize_plot(
    fig: Figure,
    axes: list[plt.Axes],
    title: str | None = None,
    smape_value: float | None = None,
    output_file: str | None = None,
    show: bool = True,
) -> None:
    """Add a legend and title to ``fig``, then save it and either show or close it.

    Args:
        fig: Figure to finalize.
        axes: Axes of ``fig``; legend entries are taken from ``axes[0]`` only.
        title: Optional figure title. When given together with ``smape_value``,
            the SMAPE is appended as ``" | SMAPE: x.xxxx"``.
        smape_value: Optional SMAPE score; only displayed when ``title`` is set.
        output_file: If truthy, the figure is saved there at 300 dpi.
        show: If True, call ``plt.show()``; otherwise close ``fig``.
    """
    # A single figure-level legend built from the first subplot's labelled artists.
    handles, labels = axes[0].get_legend_handles_labels()
    if handles:
        fig.legend(handles, labels, loc="upper right")

    if title:
        if smape_value is not None:
            title = f"{title} | SMAPE: {smape_value:.4f}"
        fig.suptitle(title, fontsize=16)

    # Operate on ``fig`` explicitly: plt.tight_layout/plt.savefig act on pyplot's
    # *current* figure, which is not guaranteed to be ``fig``. The rect reserves
    # room at the top for the suptitle.
    fig.tight_layout(rect=(0, 0.03, 1, 0.95) if title else None)

    if output_file:
        fig.savefig(output_file, dpi=300)

    if show:
        plt.show()
    else:
        plt.close(fig)
|
|
|
|
def _as_channel_matrix(values: np.ndarray | None) -> np.ndarray | None:
    """Return ``values`` as a (time, channels) array; 1-D input becomes a single channel."""
    if values is None:
        return None
    values = np.asarray(values)
    return values[:, np.newaxis] if values.ndim == 1 else values


def plot_multivariate_timeseries(
    history_values: np.ndarray,
    future_values: np.ndarray | None = None,
    predicted_values: np.ndarray | None = None,
    start: np.datetime64 | pd.Timestamp | None = None,
    frequency: Frequency | str | None = None,
    title: str | None = None,
    output_file: str | None = None,
    show: bool = True,
    lower_bound: np.ndarray | None = None,
    upper_bound: np.ndarray | None = None,
) -> Figure:
    """Plot a multivariate time series with history, future, predictions, and uncertainty bands.

    All arrays are indexed as ``[time, channel]``; 1-D arrays are treated as a
    single channel, so univariate series can be passed directly.

    Args:
        history_values: Observed history, shape (history_length, channels).
        future_values: Optional ground truth for the forecast horizon.
        predicted_values: Optional point forecast (e.g. the median).
        start: Timestamp of the first history point (see ``_create_date_ranges``).
        frequency: Sampling frequency used to build the date axis.
        title: Optional figure title; SMAPE is appended when it can be computed.
        output_file: If set, the figure is saved to this path.
        show: Show the figure if True, otherwise close it after saving.
        lower_bound: Optional lower edge of the uncertainty band.
        upper_bound: Optional upper edge of the uncertainty band.

    Returns:
        The matplotlib figure.
    """
    history_values = _as_channel_matrix(history_values)
    future_values = _as_channel_matrix(future_values)
    predicted_values = _as_channel_matrix(predicted_values)
    lower_bound = _as_channel_matrix(lower_bound)
    upper_bound = _as_channel_matrix(upper_bound)

    # SMAPE is best-effort: a failure (e.g. mismatched shapes) must not prevent plotting.
    smape_value = None
    if predicted_values is not None and future_values is not None:
        try:
            smape_value = calculate_smape(future_values, predicted_values)
        except Exception as e:
            logger.warning("Failed to calculate SMAPE: %s", e)

    history_length = history_values.shape[0]
    num_channels = history_values.shape[1]
    # The forecast horizon comes from the predictions if present, else the ground truth.
    if predicted_values is not None:
        prediction_length = predicted_values.shape[0]
    elif future_values is not None:
        prediction_length = future_values.shape[0]
    else:
        prediction_length = 0

    history_dates, future_dates = _create_date_ranges(start, frequency, history_length, prediction_length)

    fig, axes = _setup_figure(num_channels)

    for i in range(num_channels):
        _plot_single_channel(
            ax=axes[i],
            channel_idx=i,
            history_dates=history_dates,
            future_dates=future_dates,
            history_values=history_values,
            future_values=future_values,
            predicted_values=predicted_values,
            lower_bound=lower_bound,
            upper_bound=upper_bound,
        )

    _finalize_plot(fig, axes, title, smape_value, output_file, show)

    return fig
|
|
|
|
| def _extract_quantile_predictions( |
| predicted_values: np.ndarray, |
| model_quantiles: list[float], |
| ) -> tuple[np.ndarray | None, np.ndarray | None, np.ndarray | None]: |
| """Extract median, lower, and upper bound predictions from quantile output.""" |
| try: |
| median_idx = model_quantiles.index(0.5) |
| lower_idx = model_quantiles.index(0.1) |
| upper_idx = model_quantiles.index(0.9) |
|
|
| median_preds = predicted_values[..., median_idx] |
| lower_bound = predicted_values[..., lower_idx] |
| upper_bound = predicted_values[..., upper_idx] |
|
|
| return median_preds, lower_bound, upper_bound |
| except (ValueError, IndexError): |
| logger.warning("Could not find 0.1, 0.5, 0.9 quantiles for plotting. Using median of available quantiles.") |
| median_preds = predicted_values[..., predicted_values.shape[-1] // 2] |
| return median_preds, None, None |
|
|
|
|
def plot_from_container(
    batch: BatchTimeSeriesContainer,
    sample_idx: int,
    predicted_values: np.ndarray | None = None,
    model_quantiles: list[float] | None = None,
    title: str | None = None,
    output_file: str | None = None,
    show: bool = True,
) -> Figure:
    """Plot a single sample from a BatchTimeSeriesContainer with proper quantile handling.

    Args:
        batch: Container providing ``history_values``, ``future_values``,
            ``start`` and ``frequency``, each indexed by sample.
        sample_idx: Index of the sample to plot.
        predicted_values: Optional forecast, either for this sample only or for
            the whole batch (leading batch axis). NumPy arrays and torch tensors
            are accepted. When ``model_quantiles`` is given, the last axis must
            index quantile levels.
        model_quantiles: Quantile levels of the last axis of ``predicted_values``.
        title: Optional figure title.
        output_file: Optional path to save the figure to.
        show: Show the figure if True, otherwise close it.

    Returns:
        The matplotlib figure.
    """
    history_values = batch.history_values[sample_idx].cpu().numpy()
    future_values = batch.future_values[sample_idx].cpu().numpy()

    # len() instead of truthiness so a NumPy array of levels does not raise.
    has_quantiles = model_quantiles is not None and len(model_quantiles) > 0

    median_preds = lower_bound = upper_bound = None
    if predicted_values is not None:
        if isinstance(predicted_values, torch.Tensor):
            predicted_values = predicted_values.detach().cpu().numpy()

        # A single sample's forecast has the ground truth's rank, plus one
        # trailing axis when quantiles are present; more axes means a leading
        # batch axis. (Previously any 3-D input was treated as batched, which
        # mis-indexed single-sample quantile forecasts of shape (T, C, Q).)
        # The 2-D rule is kept from the original heuristic for compatibility.
        per_sample_ndim = future_values.ndim + (1 if has_quantiles else 0)
        is_batched = predicted_values.ndim > per_sample_ndim or (
            predicted_values.ndim == 2 and predicted_values.shape[0] > future_values.shape[0]
        )
        sample_preds = predicted_values[sample_idx] if is_batched else predicted_values

        if has_quantiles:
            median_preds, lower_bound, upper_bound = _extract_quantile_predictions(sample_preds, model_quantiles)
        else:
            median_preds = sample_preds

    return plot_multivariate_timeseries(
        history_values=history_values,
        future_values=future_values,
        predicted_values=median_preds,
        start=batch.start[sample_idx],
        frequency=batch.frequency[sample_idx],
        title=title,
        output_file=output_file,
        show=show,
        lower_bound=lower_bound,
        upper_bound=upper_bound,
    )
|
|