MaTaylor commited on
Commit
3c53771
·
verified ·
1 Parent(s): 606e49b

Upload service.py

Browse files
Files changed (1) hide show
  1. service.py +9 -11
service.py CHANGED
@@ -131,19 +131,17 @@ class ChronosService:
131
  else:
132
  tensor = torch.as_tensor(forecast)
133
 
134
- if tensor.ndim == 4:
135
- samples = tensor[0, :, 0, :]
136
- mean_forecast = samples.mean(dim=0).tolist()
137
- std_forecast = samples.std(dim=0).tolist()
138
- elif tensor.ndim == 3:
139
- samples = tensor[0]
140
- mean_forecast = samples.mean(dim=0).tolist()
141
- std_forecast = samples.std(dim=0).tolist()
142
- elif tensor.ndim == 2:
143
- mean_forecast = tensor[0].tolist()
144
  std_forecast = [0.0 for _ in mean_forecast]
145
  else:
146
- raise RuntimeError(f"unexpected Chronos forecast shape: {tuple(tensor.shape)}")
 
 
147
 
148
  confidence: list[float] = []
149
  for pred, std in zip(mean_forecast, std_forecast):
 
131
  else:
132
  tensor = torch.as_tensor(forecast)
133
 
134
+ tensor = tensor.to(dtype=torch.float32)
135
+ squeezed = tensor.squeeze()
136
+ if squeezed.ndim == 0:
137
+ raise RuntimeError(f"unexpected Chronos forecast shape: {tuple(tensor.shape)}")
138
+ if squeezed.ndim == 1:
139
+ mean_forecast = squeezed.tolist()
 
 
 
 
140
  std_forecast = [0.0 for _ in mean_forecast]
141
  else:
142
+ samples = squeezed.reshape(-1, squeezed.shape[-1])
143
+ mean_forecast = samples.mean(dim=0).tolist()
144
+ std_forecast = samples.std(dim=0).tolist()
145
 
146
  confidence: list[float] = []
147
  for pred, std in zip(mean_forecast, std_forecast):