| import logging |
|
|
| import numpy as np |
| import pandas as pd |
| from gluonts.model.forecast import QuantileForecast |
|
|
| from src.data.frequency import parse_frequency |
| from src.plotting.plot_timeseries import ( |
| plot_multivariate_timeseries, |
| ) |
|
|
# Module-level logger; the plotting helpers in this module report progress and
# per-window failures through it instead of raising.
logger = logging.getLogger(__name__)
|
|
|
|
| def _prepare_data_for_plotting(input_data: dict, label_data: dict, max_context_length: int): |
| history_values = np.asarray(input_data["target"], dtype=np.float32) |
| future_values = np.asarray(label_data["target"], dtype=np.float32) |
| start_period = input_data["start"] |
|
|
| def ensure_time_first(arr: np.ndarray) -> np.ndarray: |
| if arr.ndim == 1: |
| return arr.reshape(-1, 1) |
| elif arr.ndim == 2: |
| if arr.shape[0] < arr.shape[1]: |
| return arr.T |
| return arr |
| else: |
| return arr.reshape(arr.shape[-1], -1).T |
|
|
| history_values = ensure_time_first(history_values) |
| future_values = ensure_time_first(future_values) |
|
|
| if max_context_length is not None and history_values.shape[0] > max_context_length: |
| history_values = history_values[-max_context_length:] |
|
|
| |
| start_timestamp = ( |
| start_period.to_timestamp() if hasattr(start_period, "to_timestamp") else pd.Timestamp(start_period) |
| ) |
| return history_values, future_values, start_timestamp |
|
|
|
|
def _extract_quantile_predictions(
    forecast,
) -> tuple[np.ndarray | None, np.ndarray | None, np.ndarray | None]:
    """Pull a point forecast and, when available, a 0.1-0.9 quantile band.

    For a ``QuantileForecast`` the 0.5 quantile is the point forecast and the
    0.1 / 0.9 quantiles form the lower / upper band. For any other forecast,
    ``forecast.samples`` is reduced to its median over the leading (sample)
    axis and no band is returned.

    Returns:
        ``(median, lower, upper)``. Each entry is a 2-D array or ``None``.
        ``(None, None, None)`` is returned when no point forecast can be
        extracted. A missing band yields ``lower = upper = None``. Failures are
        logged at debug level rather than raised.
    """

    def as_2d(arr):
        # 1-D -> single column; 2-D kept; higher rank folded into (first axis, rest).
        if arr is None:
            return None
        arr = np.asarray(arr)
        if arr.ndim == 1:
            return arr.reshape(-1, 1)
        if arr.ndim == 2:
            return arr
        return arr.reshape(arr.shape[0], -1)

    if isinstance(forecast, QuantileForecast):
        try:
            median_pred = as_2d(forecast.quantile(0.5))
        except Exception:
            logger.debug("Could not extract the 0.5 quantile from forecast", exc_info=True)
            return None, None, None
        try:
            lower_bound = as_2d(forecast.quantile(0.1))
            upper_bound = as_2d(forecast.quantile(0.9))
        except Exception:
            # The band is optional: plot the median alone if either side is missing.
            logger.debug("Could not extract the 0.1/0.9 quantiles from forecast", exc_info=True)
            lower_bound = upper_bound = None
        return median_pred, lower_bound, upper_bound

    try:
        samples = forecast.samples
        if samples.ndim == 1:
            median_pred = samples
        elif samples.ndim in (2, 3):
            # Median over the sample axis; a single sample is its own median.
            median_pred = np.median(samples, axis=0)
        else:
            # Unexpected rank: fall back to the first entry (0-d arrays raise here).
            median_pred = samples[0] if len(samples) > 0 else samples
        return as_2d(median_pred), None, None
    except Exception:
        logger.debug("Could not extract a median from forecast samples", exc_info=True)
        return None, None, None
|
|
|
|
def _align_to_target_shape(pred_arr, target_arr):
    """Best-effort reshape of a prediction array so it matches ``target_arr``.

    Both arrays are viewed as 2-D ``(time, channels)``; 1-D inputs become a
    single column. The following heuristics apply only when the shapes differ:

    * same length, prediction has one channel -> repeat it for every target channel;
    * same length, target has one channel -> keep the prediction's first channel;
    * same channel count, different length -> truncate the prediction to the
      target length (a shorter prediction is left as-is);
    * the transposed prediction matches -> transpose it;
    * otherwise, if there are enough values, take the first ``len(target)``
      values in row-major order as one column and repeat it across channels.

    Any other mismatch is returned unchanged. ``None`` is passed through.
    Repetition uses ``np.repeat`` so the result is writable, unlike the
    read-only view returned by ``np.broadcast_to``.
    """
    if pred_arr is None:
        return None
    pred = np.asarray(pred_arr)
    target = np.asarray(target_arr)
    if pred.ndim == 1:
        pred = pred.reshape(-1, 1)
    if target.ndim == 1:
        target = target.reshape(-1, 1)
    if pred.shape == target.shape:
        return pred

    n_time, n_chan = target.shape[0], target.shape[1]
    if pred.shape[0] == n_time:
        if pred.shape[1] == 1 and n_chan > 1:
            pred = np.repeat(pred, n_chan, axis=1)
        elif pred.shape[1] > 1 and n_chan == 1:
            pred = pred[:, :1]
    elif pred.shape[1] == n_chan:
        pred = pred[:n_time]
    elif pred.T.shape == target.shape:
        pred = pred.T
    elif pred.size >= n_time:
        pred = pred.flatten()[:n_time].reshape(-1, 1)
        if n_chan > 1:
            pred = np.repeat(pred, n_chan, axis=1)
    return pred


def _create_plot(
    input_data: dict,
    label_data: dict,
    forecast,
    dataset_full_name: str,
    dataset_freq: str,
    max_context_length: int | None,
    title: str | None = None,
):
    """Render one forecast window with its history, ground truth and prediction.

    Args:
        input_data: History mapping for the window (``"target"`` and ``"start"``).
        label_data: Ground-truth mapping for the window (``"target"``).
        forecast: Forecast object for the window (quantile or sample based).
        dataset_full_name: Used in the default title and in log messages.
        dataset_freq: Frequency string passed through ``parse_frequency``.
        max_context_length: If not ``None``, history is cut to this many steps.
        title: Optional title; defaults to ``"GIFT-Eval: <dataset_full_name>"``.

    Returns:
        The figure returned by ``plot_multivariate_timeseries``, or ``None`` if
        no prediction could be extracted or any step failed. Failures are logged
        (with traceback at debug level), never raised.
    """
    try:
        history_values, future_values, start_timestamp = _prepare_data_for_plotting(
            input_data, label_data, max_context_length
        )
        median_pred, lower_bound, upper_bound = _extract_quantile_predictions(forecast)
        if median_pred is None:
            logger.warning("Could not extract predictions for %s", dataset_full_name)
            return None

        # Predictions and bands must line up with the ground truth for plotting.
        median_pred = _align_to_target_shape(median_pred, future_values)
        lower_bound = _align_to_target_shape(lower_bound, future_values)
        upper_bound = _align_to_target_shape(upper_bound, future_values)

        title = title or f"GIFT-Eval: {dataset_full_name}"
        frequency = parse_frequency(dataset_freq)
        return plot_multivariate_timeseries(
            history_values=history_values,
            future_values=future_values,
            predicted_values=median_pred,
            lower_bound=lower_bound,
            upper_bound=upper_bound,
            start=start_timestamp,
            frequency=frequency,
            title=title,
            show=False,
        )
    except Exception as e:
        # Best-effort plotting: one bad window must not abort the whole evaluation.
        logger.warning("Failed to create plot for %s: %s", dataset_full_name, e)
        logger.debug("Traceback for failed plot of %s", dataset_full_name, exc_info=True)
        return None
|
|
|
|
def create_plots_for_dataset(
    forecasts: list,
    test_data,
    dataset_metadata,
    max_plots: int,
    max_context_length: int | None,
) -> list[tuple[object, str]]:
    """Render up to ``max_plots`` forecast windows for one dataset.

    Args:
        forecasts: Per-window forecasts, in the same order as ``test_data``.
        test_data: Object exposing iterables ``input`` and ``label`` with one
            entry per window (presumably a GluonTS ``TestData`` -- confirm).
        dataset_metadata: Object whose optional ``full_name`` and ``freq``
            attributes are used for titles and file names.
        max_plots: Upper bound on the number of windows plotted (negative -> 0).
        max_context_length: Forwarded to history truncation; ``None`` keeps all.

    Returns:
        ``(figure, filename)`` pairs for the windows that rendered successfully.
        Filenames look like ``"{freq}_window_001.png"``. Windows that fail are
        logged and skipped.
    """
    from itertools import islice

    num_plots = max(0, min(len(forecasts), max_plots))
    has_full_name = hasattr(dataset_metadata, "full_name")
    full_name = dataset_metadata.full_name if has_full_name else "dataset"
    freq = getattr(dataset_metadata, "freq", "D")
    logger.info(
        "Creating %d plots for %s",
        num_plots,
        full_name if has_full_name else str(dataset_metadata),
    )

    # Pull only the windows that will be plotted instead of the whole split.
    input_data_list = list(islice(test_data.input, num_plots))
    label_data_list = list(islice(test_data.label, num_plots))

    figures_with_names: list[tuple[object, str]] = []
    for i, forecast in enumerate(islice(forecasts, num_plots)):
        window = i + 1
        if i >= len(input_data_list) or i >= len(label_data_list):
            # The split is shorter than the forecast list; later windows are missing too.
            logger.warning(
                "Error creating plot for window %d: test data has only %d input / %d label windows",
                window,
                len(input_data_list),
                len(label_data_list),
            )
            break
        title = (
            f"GIFT-Eval: {full_name} - Window {window}/{num_plots}"
            if has_full_name
            else f"Window {window}/{num_plots}"
        )
        try:
            fig = _create_plot(
                input_data=input_data_list[i],
                label_data=label_data_list[i],
                forecast=forecast,
                dataset_full_name=full_name,
                dataset_freq=freq,
                max_context_length=max_context_length,
                title=title,
            )
        except Exception as e:
            logger.warning("Error creating plot for window %d: %s", window, e)
            continue
        if fig is not None:
            figures_with_names.append((fig, f"{freq}_window_{window:03d}.png"))
    return figures_with_names
|
|