MaTaylor commited on
Commit
606e49b
·
verified ·
1 Parent(s): 4cd18a1

Upload service.py

Browse files
Files changed (1) hide show
  1. service.py +10 -2
service.py CHANGED
@@ -95,7 +95,11 @@ class ChronosService:
95
  assert self._pipeline is not None
96
 
97
  torch = self._torch
98
- context = torch.tensor(close_prices[-self.max_context_length :], dtype=torch.float32)
 
 
 
 
99
  forecast = self._pipeline.predict(
100
  context,
101
  prediction_length=self.max_horizon_step,
@@ -127,7 +131,11 @@ class ChronosService:
127
  else:
128
  tensor = torch.as_tensor(forecast)
129
 
130
- if tensor.ndim == 3:
 
 
 
 
131
  samples = tensor[0]
132
  mean_forecast = samples.mean(dim=0).tolist()
133
  std_forecast = samples.std(dim=0).tolist()
 
95
  assert self._pipeline is not None
96
 
97
  torch = self._torch
98
+ # Chronos-2 expects (n_series, n_variates, history_length) for tensor input.
99
+ context = torch.tensor(
100
+ close_prices[-self.max_context_length :],
101
+ dtype=torch.float32,
102
+ ).reshape(1, 1, -1)
103
  forecast = self._pipeline.predict(
104
  context,
105
  prediction_length=self.max_horizon_step,
 
131
  else:
132
  tensor = torch.as_tensor(forecast)
133
 
134
+ if tensor.ndim == 4:
135
+ samples = tensor[0, :, 0, :]
136
+ mean_forecast = samples.mean(dim=0).tolist()
137
+ std_forecast = samples.std(dim=0).tolist()
138
+ elif tensor.ndim == 3:
139
  samples = tensor[0]
140
  mean_forecast = samples.mean(dim=0).tolist()
141
  std_forecast = samples.std(dim=0).tolist()