MaTaylor commited on
Commit
7708226
·
verified ·
1 Parent(s): 3c53771

Upload service.py

Browse files
Files changed (1) hide show
  1. service.py +29 -3
service.py CHANGED
@@ -124,8 +124,16 @@ class ChronosService:
124
 
125
  def _extract_forecast(self, forecast: Any) -> tuple[list[float], list[float]]:
126
  assert self._torch is not None
 
127
  torch = self._torch
128
 
 
 
 
 
 
 
 
129
  if hasattr(forecast, "detach"):
130
  tensor = forecast.detach().cpu()
131
  else:
@@ -139,9 +147,27 @@ class ChronosService:
139
  mean_forecast = squeezed.tolist()
140
  std_forecast = [0.0 for _ in mean_forecast]
141
  else:
142
- samples = squeezed.reshape(-1, squeezed.shape[-1])
143
- mean_forecast = samples.mean(dim=0).tolist()
144
- std_forecast = samples.std(dim=0).tolist()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
145
 
146
  confidence: list[float] = []
147
  for pred, std in zip(mean_forecast, std_forecast):
 
124
 
125
  def _extract_forecast(self, forecast: Any) -> tuple[list[float], list[float]]:
126
  assert self._torch is not None
127
+ assert self._pipeline is not None
128
  torch = self._torch
129
 
130
+ if isinstance(forecast, (list, tuple)):
131
+ if not forecast:
132
+ raise RuntimeError("empty Chronos forecast output")
133
+ if len(forecast) != 1:
134
+ raise RuntimeError(f"unexpected Chronos batch size: {len(forecast)}")
135
+ forecast = forecast[0]
136
+
137
  if hasattr(forecast, "detach"):
138
  tensor = forecast.detach().cpu()
139
  else:
 
147
  mean_forecast = squeezed.tolist()
148
  std_forecast = [0.0 for _ in mean_forecast]
149
  else:
150
+ if squeezed.ndim == 2:
151
+ quantile_tensor = squeezed
152
+ else:
153
+ quantile_tensor = squeezed.reshape(-1, squeezed.shape[-2], squeezed.shape[-1]).mean(dim=0)
154
+
155
+ quantiles = list(getattr(self._pipeline, "quantiles", []))
156
+ if not quantiles:
157
+ median_idx = quantile_tensor.shape[0] // 2
158
+ mean_forecast = quantile_tensor[median_idx].tolist()
159
+ std_forecast = [0.0 for _ in mean_forecast]
160
+ else:
161
+ median_idx = min(range(len(quantiles)), key=lambda idx: abs(quantiles[idx] - 0.5))
162
+ lower_idx = min(range(len(quantiles)), key=lambda idx: abs(quantiles[idx] - 0.1))
163
+ upper_idx = min(range(len(quantiles)), key=lambda idx: abs(quantiles[idx] - 0.9))
164
+ mean_forecast = quantile_tensor[median_idx].tolist()
165
+ lower_forecast = quantile_tensor[lower_idx].tolist()
166
+ upper_forecast = quantile_tensor[upper_idx].tolist()
167
+ std_forecast = [
168
+ max(0.0, (float(upper) - float(lower)) / 2.0)
169
+ for lower, upper in zip(lower_forecast, upper_forecast)
170
+ ]
171
 
172
  confidence: list[float] = []
173
  for pred, std in zip(mean_forecast, std_forecast):