File size: 1,613 Bytes
63d24c3
29ae4db
 
 
 
a094858
 
 
29ae4db
 
 
 
cf178b5
a094858
 
 
 
 
29ae4db
 
 
 
a094858
29ae4db
a094858
29ae4db
cf178b5
29ae4db
 
cf178b5
29ae4db
 
 
cf178b5
29ae4db
 
 
 
 
 
a094858
29ae4db
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
import torch
import numpy as np
import pandas as pd
import os
from model.kronos import Kronos, KronosTokenizer, KronosPredictor

class EndpointHandler:
    """Inference endpoint that forecasts future candles with a Kronos model.

    The tokenizer and model are loaded once at construction. Each call turns
    the request payload into an OHLCV DataFrame, attaches synthetic regularly
    spaced timestamps, runs the predictor and returns the predicted
    open/high/low/close rows.
    """

    # Column order for row-style input; a row may omit the trailing "volume".
    _COLUMNS = ["open", "high", "low", "close", "volume"]
    # Columns returned to the caller for every predicted step.
    _PRICE_COLUMNS = ["open", "high", "low", "close"]

    def __init__(self, path=""):
        """Load the tokenizer and model from *path* and build a CPU predictor.

        Args:
            path: Model directory. The model is loaded from ``path`` itself and
                the tokenizer from its ``tokenizer`` subdirectory.
        """
        tokenizer_path = os.path.join(path, "tokenizer")
        tokenizer = KronosTokenizer.from_pretrained(tokenizer_path)
        model = Kronos.from_pretrained(path)
        self.predictor = KronosPredictor(model, tokenizer, device="cpu", max_context=512)

    def __call__(self, data):
        """Forecast the next candles for the series in *data*.

        Args:
            data: Request payload (a dict).
                ``data["inputs"]`` is either a list of rows
                ``[open, high, low, close]`` / ``[open, high, low, close, volume]``
                or a flat list of prices, which is used for all four price columns.
                ``data["parameters"]`` may contain ``prediction_length``
                (default 8), ``freq`` (bar spacing, default ``"15min"``) and the
                sampling settings ``T`` (1.0), ``top_k`` (0), ``top_p`` (0.9)
                and ``sample_count`` (5).

        Returns:
            ``{"predictions": [[open, high, low, close], ...]}`` with one row
            per predicted step, or ``{"error": message}`` for an invalid request.
        """
        data = data or {}
        inputs = data.get("inputs")
        # Check explicitly instead of using truthiness so array-like inputs work.
        if inputs is None or len(inputs) == 0:
            return {"error": "No input data"}
        # `or {}` also covers an explicit JSON null for "parameters".
        parameters = data.get("parameters") or {}

        try:
            prediction_length = int(parameters.get("prediction_length", 8))
            sample_count = int(parameters.get("sample_count", 5))
            top_k = int(parameters.get("top_k", 0))
            temperature = float(parameters.get("T", 1.0))
            top_p = float(parameters.get("top_p", 0.9))
        except (TypeError, ValueError):
            return {"error": "Invalid generation parameters"}
        if prediction_length < 1:
            return {"error": "prediction_length must be a positive integer"}
        if sample_count < 1:
            return {"error": "sample_count must be a positive integer"}
        freq = parameters.get("freq", "15min")

        try:
            df = self._build_frame(inputs)
            x_timestamps, y_timestamps = self._make_timestamps(len(df), prediction_length, freq)
        except (TypeError, ValueError) as err:
            return {"error": str(err)}

        pred_df = self.predictor.predict(
            df, x_timestamps, y_timestamps,
            pred_len=prediction_length,
            T=temperature, top_k=top_k, top_p=top_p,
            sample_count=sample_count, verbose=False,
        )
        return {"predictions": pred_df[self._PRICE_COLUMNS].to_numpy().tolist()}

    @classmethod
    def _build_frame(cls, inputs):
        """Convert non-empty request *inputs* into a float OHLCV DataFrame.

        Raises:
            ValueError: if rows do not all have 4 or 5 values, if a value is
                non-numeric, or if any value is missing (None/NaN).
        """
        if isinstance(inputs[0], (list, tuple)):
            width = len(inputs[0])
            if width not in (4, 5):
                raise ValueError("Each input row must have 4 (OHLC) or 5 (OHLCV) values")
            # pandas would NaN-pad short rows or reject long ones; fail clearly instead.
            if not all(isinstance(row, (list, tuple)) and len(row) == width for row in inputs):
                raise ValueError("All input rows must have the same number of values")
            df = pd.DataFrame(inputs, columns=cls._COLUMNS[:width])
        else:
            # A flat series is used as open, high, low and close alike.
            df = pd.DataFrame({col: inputs for col in cls._PRICE_COLUMNS})
        if "volume" not in df.columns:
            df["volume"] = 0.0
        # Accept numeric strings, and reject non-numeric values, via the ValueError from astype.
        df = df.astype(float)
        if df.isnull().to_numpy().any():
            raise ValueError("Input data contains missing values")
        return df

    @staticmethod
    def _make_timestamps(history_len, pred_len, freq="15min", now=None):
        """Return ``(history, future)`` timestamp Series on a regular *freq* grid.

        The payload carries no timestamps, so the history is placed to end at
        the current local time (or *now*) floored to *freq*. The forecast then
        continues on the same grid, one step after the last history point.
        Flooring keeps the minute/hour values on the bar grid, as real candle
        data would have them. *freq* must be a fixed frequency such as
        ``"15min"`` or ``"1h"``; otherwise pandas raises ``ValueError``.

        NOTE(review): Series are returned rather than a DatetimeIndex because
        the predictor's time-feature code appears to use the ``.dt`` accessor,
        which a DatetimeIndex lacks. Confirm against ``model.kronos``.
        """
        anchor = pd.Timestamp.now() if now is None else pd.Timestamp(now)
        anchor = anchor.floor(freq)
        history = pd.date_range(end=anchor, periods=history_len, freq=freq)
        # Generate pred_len + 1 points from the anchor and drop the anchor itself.
        future = pd.date_range(start=anchor, periods=pred_len + 1, freq=freq)[1:]
        return pd.Series(history), pd.Series(future)