Spaces:
Running
Running
Upload service.py
Browse files- service.py +9 -11
service.py
CHANGED
|
@@ -131,19 +131,17 @@ class ChronosService:
|
|
| 131 |
else:
|
| 132 |
tensor = torch.as_tensor(forecast)
|
| 133 |
|
| 134 |
-
|
| 135 |
-
|
| 136 |
-
|
| 137 |
-
|
| 138 |
-
|
| 139 |
-
|
| 140 |
-
mean_forecast = samples.mean(dim=0).tolist()
|
| 141 |
-
std_forecast = samples.std(dim=0).tolist()
|
| 142 |
-
elif tensor.ndim == 2:
|
| 143 |
-
mean_forecast = tensor[0].tolist()
|
| 144 |
std_forecast = [0.0 for _ in mean_forecast]
|
| 145 |
else:
|
| 146 |
-
|
|
|
|
|
|
|
| 147 |
|
| 148 |
confidence: list[float] = []
|
| 149 |
for pred, std in zip(mean_forecast, std_forecast):
|
|
|
|
| 131 |
else:
|
| 132 |
tensor = torch.as_tensor(forecast)
|
| 133 |
|
| 134 |
+
tensor = tensor.to(dtype=torch.float32)
|
| 135 |
+
squeezed = tensor.squeeze()
|
| 136 |
+
if squeezed.ndim == 0:
|
| 137 |
+
raise RuntimeError(f"unexpected Chronos forecast shape: {tuple(tensor.shape)}")
|
| 138 |
+
if squeezed.ndim == 1:
|
| 139 |
+
mean_forecast = squeezed.tolist()
|
|
|
|
|
|
|
|
|
|
|
|
|
| 140 |
std_forecast = [0.0 for _ in mean_forecast]
|
| 141 |
else:
|
| 142 |
+
samples = squeezed.reshape(-1, squeezed.shape[-1])
|
| 143 |
+
mean_forecast = samples.mean(dim=0).tolist()
|
| 144 |
+
std_forecast = samples.std(dim=0).tolist()
|
| 145 |
|
| 146 |
confidence: list[float] = []
|
| 147 |
for pred, std in zip(mean_forecast, std_forecast):
|