File size: 1,613 Bytes
63d24c3 29ae4db a094858 29ae4db cf178b5 a094858 29ae4db a094858 29ae4db a094858 29ae4db cf178b5 29ae4db cf178b5 29ae4db cf178b5 29ae4db a094858 29ae4db | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 | import torch
import numpy as np
import pandas as pd
import os
from model.kronos import Kronos, KronosTokenizer, KronosPredictor
class EndpointHandler:
    """Callable request handler that forecasts candles with a Kronos predictor.

    The constructor loads the tokenizer and model once; each call turns a
    request payload into a DataFrame, runs the predictor, and returns the
    predicted open/high/low/close rows.
    """

    # Column names assigned positionally to list-style input rows.
    _COLUMNS = ("open", "high", "low", "close", "volume")

    def __init__(self, path=""):
        """Load the tokenizer and model and build the predictor.

        Args:
            path: Directory passed to ``Kronos.from_pretrained``; the
                tokenizer is loaded from its ``tokenizer`` subdirectory.
        """
        tokenizer_path = os.path.join(path, "tokenizer")
        tokenizer = KronosTokenizer.from_pretrained(tokenizer_path)
        model = Kronos.from_pretrained(path)
        self.predictor = KronosPredictor(
            model, tokenizer, device="cpu", max_context=512
        )

    @classmethod
    def _to_frame(cls, inputs):
        """Build the predictor's input DataFrame from the request rows.

        Args:
            inputs: Non-empty list. Either a list of rows (each a list of
                one to five values mapped positionally to open, high, low,
                close, volume) or a flat list of values reused for all four
                price columns.

        Returns:
            A DataFrame that always contains a ``volume`` column.

        Raises:
            ValueError: If the rows are empty, wider than five values, or
                not all the same width.
        """
        first = inputs[0]
        if isinstance(first, (list, tuple)):
            width = len(first)
            if not 1 <= width <= len(cls._COLUMNS):
                raise ValueError(
                    f"Each input row must have between 1 and "
                    f"{len(cls._COLUMNS)} values, got {width}"
                )
            if any(
                not isinstance(row, (list, tuple)) or len(row) != width
                for row in inputs
            ):
                raise ValueError("All input rows must have the same length")
            # NOTE(review): rows narrower than four values leave some price
            # columns absent; whether the predictor tolerates that is not
            # visible from this file.
            frame = pd.DataFrame(
                list(inputs), columns=list(cls._COLUMNS[:width])
            )
        else:
            # A flat series is reused for every price column.
            frame = pd.DataFrame({name: inputs for name in cls._COLUMNS[:4]})
        if "volume" not in frame.columns:
            frame["volume"] = 0.0
        return frame

    def __call__(self, data):
        """Run a forecast for one request payload.

        Args:
            data: Mapping with an ``"inputs"`` list (see ``_to_frame``) and an
                optional ``"parameters"`` mapping whose ``"prediction_length"``
                (default 8) is the number of future steps to predict.

        Returns:
            ``{"predictions": rows}`` where each row is
            ``[open, high, low, close]``, or ``{"error": message}`` when the
            payload is empty or malformed.
        """
        # `or` also covers keys that are present but explicitly null.
        inputs = data.get("inputs") or []
        parameters = data.get("parameters") or {}

        if not isinstance(inputs, (list, tuple)):
            return {"error": "inputs must be a list"}
        if not inputs:
            return {"error": "No input data"}

        try:
            prediction_length = int(parameters.get("prediction_length", 8))
        except (TypeError, ValueError):
            return {"error": "prediction_length must be a positive integer"}
        if prediction_length <= 0:
            return {"error": "prediction_length must be a positive integer"}

        try:
            df = self._to_frame(inputs)
        except ValueError as err:
            return {"error": str(err)}

        # The request carries no timestamps, so synthesize a 15-minute grid
        # ending now for the history and continuing after it for the forecast.
        # NOTE(review): these are DatetimeIndex objects; confirm that
        # KronosPredictor.predict accepts them as-is.
        now = pd.Timestamp.now()
        x_timestamps = pd.date_range(end=now, periods=len(df), freq="15min")
        y_timestamps = pd.date_range(
            start=now + pd.Timedelta("15min"),
            periods=prediction_length,
            freq="15min",
        )

        pred_df = self.predictor.predict(
            df, x_timestamps, y_timestamps,
            pred_len=prediction_length,
            T=1.0, top_k=0, top_p=0.9,
            sample_count=5, verbose=False
        )
        result = pred_df[["open", "high", "low", "close"]].values.tolist()
        return {"predictions": result}
|