Spaces:
Running
Running
Upload service.py
Browse files- service.py +29 -3
service.py
CHANGED
|
@@ -124,8 +124,16 @@ class ChronosService:
|
|
| 124 |
|
| 125 |
def _extract_forecast(self, forecast: Any) -> tuple[list[float], list[float]]:
|
| 126 |
assert self._torch is not None
|
|
|
|
| 127 |
torch = self._torch
|
| 128 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 129 |
if hasattr(forecast, "detach"):
|
| 130 |
tensor = forecast.detach().cpu()
|
| 131 |
else:
|
|
@@ -139,9 +147,27 @@ class ChronosService:
|
|
| 139 |
mean_forecast = squeezed.tolist()
|
| 140 |
std_forecast = [0.0 for _ in mean_forecast]
|
| 141 |
else:
|
| 142 |
-
|
| 143 |
-
|
| 144 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 145 |
|
| 146 |
confidence: list[float] = []
|
| 147 |
for pred, std in zip(mean_forecast, std_forecast):
|
|
|
|
| 124 |
|
| 125 |
def _extract_forecast(self, forecast: Any) -> tuple[list[float], list[float]]:
|
| 126 |
assert self._torch is not None
|
| 127 |
+
assert self._pipeline is not None
|
| 128 |
torch = self._torch
|
| 129 |
|
| 130 |
+
if isinstance(forecast, (list, tuple)):
|
| 131 |
+
if not forecast:
|
| 132 |
+
raise RuntimeError("empty Chronos forecast output")
|
| 133 |
+
if len(forecast) != 1:
|
| 134 |
+
raise RuntimeError(f"unexpected Chronos batch size: {len(forecast)}")
|
| 135 |
+
forecast = forecast[0]
|
| 136 |
+
|
| 137 |
if hasattr(forecast, "detach"):
|
| 138 |
tensor = forecast.detach().cpu()
|
| 139 |
else:
|
|
|
|
| 147 |
mean_forecast = squeezed.tolist()
|
| 148 |
std_forecast = [0.0 for _ in mean_forecast]
|
| 149 |
else:
|
| 150 |
+
if squeezed.ndim == 2:
|
| 151 |
+
quantile_tensor = squeezed
|
| 152 |
+
else:
|
| 153 |
+
quantile_tensor = squeezed.reshape(-1, squeezed.shape[-2], squeezed.shape[-1]).mean(dim=0)
|
| 154 |
+
|
| 155 |
+
quantiles = list(getattr(self._pipeline, "quantiles", []))
|
| 156 |
+
if not quantiles:
|
| 157 |
+
median_idx = quantile_tensor.shape[0] // 2
|
| 158 |
+
mean_forecast = quantile_tensor[median_idx].tolist()
|
| 159 |
+
std_forecast = [0.0 for _ in mean_forecast]
|
| 160 |
+
else:
|
| 161 |
+
median_idx = min(range(len(quantiles)), key=lambda idx: abs(quantiles[idx] - 0.5))
|
| 162 |
+
lower_idx = min(range(len(quantiles)), key=lambda idx: abs(quantiles[idx] - 0.1))
|
| 163 |
+
upper_idx = min(range(len(quantiles)), key=lambda idx: abs(quantiles[idx] - 0.9))
|
| 164 |
+
mean_forecast = quantile_tensor[median_idx].tolist()
|
| 165 |
+
lower_forecast = quantile_tensor[lower_idx].tolist()
|
| 166 |
+
upper_forecast = quantile_tensor[upper_idx].tolist()
|
| 167 |
+
std_forecast = [
|
| 168 |
+
max(0.0, (float(upper) - float(lower)) / 2.0)
|
| 169 |
+
for lower, upper in zip(lower_forecast, upper_forecast)
|
| 170 |
+
]
|
| 171 |
|
| 172 |
confidence: list[float] = []
|
| 173 |
for pred, std in zip(mean_forecast, std_forecast):
|