Spaces:
Running
Running
Upload service.py
Browse files- service.py +10 -2
service.py
CHANGED
|
@@ -95,7 +95,11 @@ class ChronosService:
|
|
| 95 |
assert self._pipeline is not None
|
| 96 |
|
| 97 |
torch = self._torch
|
| 98 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 99 |
forecast = self._pipeline.predict(
|
| 100 |
context,
|
| 101 |
prediction_length=self.max_horizon_step,
|
|
@@ -127,7 +131,11 @@ class ChronosService:
|
|
| 127 |
else:
|
| 128 |
tensor = torch.as_tensor(forecast)
|
| 129 |
|
| 130 |
-
if tensor.ndim ==
|
|
|
|
|
|
|
|
|
|
|
|
|
| 131 |
samples = tensor[0]
|
| 132 |
mean_forecast = samples.mean(dim=0).tolist()
|
| 133 |
std_forecast = samples.std(dim=0).tolist()
|
|
|
|
| 95 |
assert self._pipeline is not None
|
| 96 |
|
| 97 |
torch = self._torch
|
| 98 |
+
# Chronos-2 expects (n_series, n_variates, history_length) for tensor input.
|
| 99 |
+
context = torch.tensor(
|
| 100 |
+
close_prices[-self.max_context_length :],
|
| 101 |
+
dtype=torch.float32,
|
| 102 |
+
).reshape(1, 1, -1)
|
| 103 |
forecast = self._pipeline.predict(
|
| 104 |
context,
|
| 105 |
prediction_length=self.max_horizon_step,
|
|
|
|
| 131 |
else:
|
| 132 |
tensor = torch.as_tensor(forecast)
|
| 133 |
|
| 134 |
+
if tensor.ndim == 4:
|
| 135 |
+
samples = tensor[0, :, 0, :]
|
| 136 |
+
mean_forecast = samples.mean(dim=0).tolist()
|
| 137 |
+
std_forecast = samples.std(dim=0).tolist()
|
| 138 |
+
elif tensor.ndim == 3:
|
| 139 |
samples = tensor[0]
|
| 140 |
mean_forecast = samples.mean(dim=0).tolist()
|
| 141 |
std_forecast = samples.std(dim=0).tolist()
|