| import torch |
| import numpy as np |
| import pandas as pd |
| import os |
| from model.kronos import Kronos, KronosTokenizer, KronosPredictor |
|
|
| class EndpointHandler: |
| def __init__(self, path=""): |
| tokenizer_path = os.path.join(path, "tokenizer") |
| tokenizer = KronosTokenizer.from_pretrained(tokenizer_path) |
| model = Kronos.from_pretrained(path) |
| self.predictor = KronosPredictor(model, tokenizer, device="cpu", max_context=512) |
|
|
| def __call__(self, data): |
| inputs = data.get("inputs", []) |
| parameters = data.get("parameters", {}) |
| prediction_length = parameters.get("prediction_length", 8) |
|
|
| if len(inputs) == 0: |
| return {"error": "No input data"} |
|
|
| cols = ["open", "high", "low", "close", "volume"] |
| if isinstance(inputs[0], list): |
| df = pd.DataFrame(inputs, columns=cols[:len(inputs[0])]) |
| else: |
| df = pd.DataFrame({"open": inputs, "high": inputs, "low": inputs, "close": inputs}) |
|
|
| if "volume" not in df.columns: |
| df["volume"] = 0.0 |
|
|
| now = pd.Timestamp.now() |
| x_timestamps = pd.date_range(end=now, periods=len(df), freq="15min") |
| y_timestamps = pd.date_range(start=now + pd.Timedelta("15min"), periods=prediction_length, freq="15min") |
|
|
| pred_df = self.predictor.predict( |
| df, x_timestamps, y_timestamps, |
| pred_len=prediction_length, |
| T=1.0, top_k=0, top_p=0.9, |
| sample_count=5, verbose=False |
| ) |
|
|
| result = pred_df[["open", "high", "low", "close"]].values.tolist() |
| return {"predictions": result} |
|
|