Imp3rtinence committed on
Commit
cf178b5
·
1 Parent(s): a0a8201

Fix handler: use native Kronos model instead of ChronosPipeline

Browse files
Files changed (2) hide show
  1. handler.py +38 -11
  2. requirements.txt +3 -2
handler.py CHANGED
@@ -1,26 +1,53 @@
1
  import torch
2
- from chronos import ChronosPipeline
 
 
3
 
4
  class EndpointHandler:
5
  def __init__(self, path=""):
6
- self.pipeline = ChronosPipeline.from_pretrained(
7
- path,
8
- device_map="auto",
9
- torch_dtype=torch.float32,
 
 
 
 
 
 
 
 
10
  )
11
 
 
 
 
 
12
  def __call__(self, data):
13
  inputs = data.get("inputs", [])
14
  parameters = data.get("parameters", {})
15
  prediction_length = parameters.get("prediction_length", 8)
16
 
17
  if isinstance(inputs[0], list):
18
- closes = [candle[3] for candle in inputs]
19
  else:
20
- closes = inputs
 
 
 
 
 
 
 
 
 
 
 
21
 
22
- context = torch.tensor([closes], dtype=torch.float32)
23
- forecast = self.pipeline.predict(context, prediction_length=prediction_length, num_samples=20)
24
- median_forecast = forecast.median(dim=1).values[0].tolist()
25
 
26
- return {"predictions": median_forecast}
 
 
 
 
1
  import torch
2
+ import json
3
+ from safetensors.torch import load_file
4
+ from model.kronos import Kronos
5
 
6
class EndpointHandler:
    """Inference-endpoint handler serving a natively-loaded Kronos model.

    Loads the model configuration and safetensors weights from ``path`` at
    startup, then serves OHLCV time-series forecasts via ``__call__``.
    """

    def __init__(self, path=""):
        """Build the Kronos model from ``{path}/config.json`` and load weights.

        Missing config keys fall back to hard-coded defaults (presumably the
        values the checkpoint was trained with — TODO confirm against the repo).
        """
        with open(f"{path}/config.json", "r") as f:
            config = json.load(f)

        self.model = Kronos(
            input_dim=config.get("input_dim", 5),
            d_model=config.get("d_model", 256),
            nhead=config.get("nhead", 8),
            num_layers=config.get("num_layers", 6),
            dim_feedforward=config.get("dim_feedforward", 1024),
            max_seq_len=config.get("max_seq_len", 512),
            output_dim=config.get("output_dim", 5),
            dropout=config.get("dropout", 0.1),
        )

        weights = load_file(f"{path}/model.safetensors")
        self.model.load_state_dict(weights)
        # Inference only: switch off dropout / training-mode behavior.
        self.model.eval()

    def __call__(self, data):
        """Run a forecast for one request payload.

        Parameters
        ----------
        data : dict
            ``{"inputs": ..., "parameters": {...}}``. ``inputs`` is either a
            list of OHLCV candles (``[open, high, low, close, volume]``) or a
            flat list of close prices, which is expanded into pseudo-candles
            with zero volume. ``parameters`` may carry ``prediction_length``
            (default 8).

        Returns
        -------
        dict
            ``{"predictions": <list of OHLCV candles, de-normalized back to
            price scale>, "prediction_length": <int>}``.

        Raises
        ------
        ValueError
            If ``inputs`` is empty (previously this crashed with a bare
            ``IndexError`` on ``inputs[0]``).
        """
        inputs = data.get("inputs", [])
        parameters = data.get("parameters", {})
        prediction_length = parameters.get("prediction_length", 8)

        # Guard the empty payload explicitly instead of letting inputs[0]
        # raise an opaque IndexError.
        if not inputs:
            raise ValueError("inputs must be a non-empty list")

        if isinstance(inputs[0], list):
            ohlcv = inputs
        else:
            # Flat close-price series: replicate the close into O/H/L/C and
            # use zero volume so the model still receives 5 channels.
            ohlcv = [[v, v, v, v, 0] for v in inputs]

        tensor = torch.tensor([ohlcv], dtype=torch.float32)

        # Normalize by the last close so the model sees price-relative values.
        # NOTE(review): this also divides the volume channel by last_close —
        # confirm that matches how Kronos was trained.
        last_close = ohlcv[-1][3]
        if last_close > 0:
            tensor = tensor / last_close

        with torch.no_grad():
            output = self.model(tensor)

        # Keep the final `prediction_length` steps of the (batch=1) output.
        predicted = output[0, -prediction_length:, :].tolist()

        # De-normalize back to absolute price scale.
        if last_close > 0:
            predicted = [[v * last_close for v in candle] for candle in predicted]

        return {
            "predictions": predicted,
            "prediction_length": prediction_length,
        }
requirements.txt CHANGED
@@ -1,2 +1,3 @@
1
- torch
2
- chronos-forecasting
 
 
1
+ torch>=2.0.0
2
+ safetensors
3
+ einops