| import json |
| import logging |
| import os |
| from pathlib import Path |
|
|
| from gluonts.ev.metrics import ( |
| MAE, |
| MAPE, |
| MASE, |
| MSE, |
| MSIS, |
| ND, |
| NRMSE, |
| RMSE, |
| SMAPE, |
| MeanWeightedSumQuantileLoss, |
| ) |
|
|
logger = logging.getLogger(__name__)

# Set unconditionally at import time; this overwrites any value that is
# already present in the process environment.
# NOTE(review): presumably the cuBLAS workspace setting PyTorch asks for when
# deterministic algorithms are enabled -- confirm against the training code.
os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8"

# Three ``parent`` hops up from this file; the properties file is expected in
# the ``data`` directory underneath that location.
_MODULE_DIR = Path(__file__).parent.parent.parent
DATASET_PROPERTIES_PATH = _MODULE_DIR / "data" / "dataset_properties.json"


def _load_dataset_properties(path) -> dict:
    """Read the dataset-properties JSON file, falling back to an empty dict.

    The load is best-effort: a missing, unreadable, undecodable or malformed
    file is logged as a warning and never raised, so importing this module
    cannot fail because of it.

    Args:
        path: Location of the JSON file (``str`` or ``Path``).

    Returns:
        The parsed JSON object, or ``{}`` when the file could not be loaded
        or its top-level value is not a JSON object.
    """
    try:
        # Explicit UTF-8 so decoding does not depend on the platform locale.
        with open(path, encoding="utf-8") as f:
            loaded = json.load(f)
    except (OSError, ValueError) as exc:
        # OSError: missing/unreadable file. ValueError: covers both
        # json.JSONDecodeError and UnicodeDecodeError.
        logger.warning(
            "Could not load dataset properties from %s: %s. Domain and num_variates will fall back to defaults.",
            path,
            exc,
        )
        return {}

    if not isinstance(loaded, dict):
        # The fallback above is a dict, so keep the exported type consistent
        # instead of handing a list/scalar to code that expects a mapping.
        logger.warning(
            "Dataset properties in %s must be a JSON object, got %s. "
            "Domain and num_variates will fall back to defaults.",
            path,
            type(loaded).__name__,
        )
        return {}

    return loaded


DATASET_PROPERTIES = _load_dataset_properties(DATASET_PROPERTIES_PATH)
|
|
|
|
| |
# Dataset identifiers, one family per line. An entry is either a bare dataset
# name or "<name>/<suffix>".
# NOTE(review): the suffix looks like a pandas-style frequency alias and the
# tuple presumably belongs to the "short" entry of TERMS -- confirm.
SHORT_DATASETS = (
    "m4_yearly", "m4_quarterly", "m4_monthly", "m4_weekly", "m4_daily", "m4_hourly",
    "electricity/15T", "electricity/H", "electricity/D", "electricity/W",
    "solar/10T", "solar/H", "solar/D", "solar/W",
    "hospital",
    "covid_deaths",
    "us_births/D", "us_births/M", "us_births/W",
    "saugeenday/D", "saugeenday/M", "saugeenday/W",
    "temperature_rain_with_missing",
    "kdd_cup_2018_with_missing/H", "kdd_cup_2018_with_missing/D",
    "car_parts_with_missing",
    "restaurant",
    "hierarchical_sales/D", "hierarchical_sales/W",
    "LOOP_SEATTLE/5T", "LOOP_SEATTLE/H", "LOOP_SEATTLE/D",
    "SZ_TAXI/15T", "SZ_TAXI/H",
    "M_DENSE/H", "M_DENSE/D",
    "ett1/15T", "ett1/H", "ett1/D", "ett1/W",
    "ett2/15T", "ett2/H", "ett2/D", "ett2/W",
    "jena_weather/10T", "jena_weather/H", "jena_weather/D",
    "bitbrains_fast_storage/5T", "bitbrains_fast_storage/H",
    "bitbrains_rnd/5T", "bitbrains_rnd/H",
    "bizitobs_application", "bizitobs_service", "bizitobs_l2c/5T", "bizitobs_l2c/H",
)
|
|
# Dataset identifiers, one family per line, in the same "<name>/<suffix>"
# notation used for the other dataset tuple in this module.
# NOTE(review): presumably the datasets that belong to the "medium" and "long"
# entries of TERMS -- confirm.
MED_LONG_DATASETS = (
    "electricity/15T", "electricity/H",
    "solar/10T", "solar/H",
    "kdd_cup_2018_with_missing/H",
    "LOOP_SEATTLE/5T", "LOOP_SEATTLE/H",
    "SZ_TAXI/15T",
    "M_DENSE/H",
    "ett1/15T", "ett1/H",
    "ett2/15T", "ett2/H",
    "jena_weather/10T", "jena_weather/H",
    "bitbrains_fast_storage/5T",
    "bitbrains_rnd/5T",
    "bizitobs_application", "bizitobs_service", "bizitobs_l2c/5T", "bizitobs_l2c/H",
)
|
|
| |
# Union of both dataset tuples as a list. ``dict.fromkeys`` drops repeated
# names while keeping the position of each name's first occurrence.
ALL_DATASETS = [*dict.fromkeys((*SHORT_DATASETS, *MED_LONG_DATASETS))]
|
|
|
|
| |
# The three term labels, as an immutable tuple.
# NOTE(review): presumably forecast-horizon categories -- confirm with callers.
TERMS = (
    "short",
    "medium",
    "long",
)
|
|
|
|
| |
# Shorter replacement names for four datasets, keyed by the dataset name as it
# appears before any "/<suffix>" part in the dataset tuples.
# NOTE(review): presumably used when reporting results -- confirm with callers.
PRETTY_NAMES = dict(
    saugeenday="saugeen",
    temperature_rain_with_missing="temperature_rain",
    kdd_cup_2018_with_missing="kdd_cup_2018",
    car_parts_with_missing="car_parts",
)
|
|
|
|
# Quantile levels 0.1 .. 0.9 passed to the weighted quantile loss below.
_QUANTILE_LEVELS = (0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9)

# gluonts metric definitions collected into one tuple. MSE appears twice, once
# for the "mean" forecast and once for the 0.5 quantile; every other metric is
# constructed with its library defaults.
METRICS = (
    MSE(forecast_type="mean"),
    MSE(forecast_type=0.5),
    MAE(),
    MASE(),
    MAPE(),
    SMAPE(),
    MSIS(),
    RMSE(),
    NRMSE(),
    ND(),
    MeanWeightedSumQuantileLoss(quantile_levels=list(_QUANTILE_LEVELS)),
)
|
|
|
|
# Metric names as strings: eleven entries, grouped here by the bracketed
# forecast type they carry.
# NOTE(review): presumably the result-column names gluonts produces for the
# entries of METRICS, in the same order -- confirm against the evaluation code.
STANDARD_METRIC_NAMES = (
    "MSE[mean]", "MSE[0.5]",
    "MAE[0.5]", "MASE[0.5]", "MAPE[0.5]", "sMAPE[0.5]",
    "MSIS",
    "RMSE[mean]", "NRMSE[mean]",
    "ND[0.5]",
    "mean_weighted_sum_quantile_loss",
)
|
|
|
|
# Public names exported by a star-import of this module, in alphabetical order.
__all__ = [
    "ALL_DATASETS", "DATASET_PROPERTIES", "DATASET_PROPERTIES_PATH",
    "MED_LONG_DATASETS", "METRICS", "PRETTY_NAMES",
    "SHORT_DATASETS", "STANDARD_METRIC_NAMES", "TERMS",
]
|
|