| import os |
| from typing import Any, Dict |
|
|
| import pandas as pd |
| import torch |
|
|
| from kronos import Kronos, KronosTokenizer, KronosPredictor |
|
|
|
|
| DEVICE = "cuda" if torch.cuda.is_available() else "cpu" |
|
|
|
|
def _load_components(model_dir: str = "."):
    """Load the Kronos tokenizer, model, and predictor from *model_dir*.

    On HF Inference Endpoints this runs once, at module import time.
    The predictor's maximum context length is read from the
    ``KRONOS_MAX_CONTEXT`` environment variable (default: ``512``).
    """
    tok = KronosTokenizer.from_pretrained(model_dir)
    net = Kronos.from_pretrained(model_dir).to(DEVICE)

    # Let deployments tune the context window without a code change.
    ctx = int(os.getenv("KRONOS_MAX_CONTEXT", "512"))

    return tok, net, KronosPredictor(
        model=net,
        tokenizer=tok,
        device=DEVICE,
        max_context=ctx,
    )
|
|
|
|
| TOKENIZER, MODEL, PREDICTOR = _load_components(".") |
|
|
|
|
| def predict(request: Dict[str, Any]) -> Dict[str, Any]: |
| """ |
| Entry point for Hugging Face Inference Endpoints. |
| |
| Expected input format: |
| |
| { |
| "inputs": { |
| "df": [ |
| {"open": ..., "high": ..., "low": ..., "close": ...}, |
| ... |
| ], |
| "x_timestamp": [...], # list of ISO8601 strings or timestamps |
| "y_timestamp": [...], # list of ISO8601 strings or timestamps |
| "pred_len": 120, |
| "T": 1.0, # optional |
| "top_p": 0.9, # optional |
| "sample_count": 1 # optional |
| } |
| } |
| """ |
| inputs = request.get("inputs", request) |
|
|
| df = pd.DataFrame(inputs["df"]) |
| x_timestamp = pd.to_datetime(inputs["x_timestamp"]) |
| y_timestamp = pd.to_datetime(inputs["y_timestamp"]) |
|
|
| pred_len = int(inputs["pred_len"]) |
| T = float(inputs.get("T", 1.0)) |
| top_p = float(inputs.get("top_p", 0.9)) |
| sample_count = int(inputs.get("sample_count", 1)) |
|
|
| result_df = PREDICTOR.predict( |
| df=df, |
| x_timestamp=x_timestamp, |
| y_timestamp=y_timestamp, |
| pred_len=pred_len, |
| T=T, |
| top_p=top_p, |
| sample_count=sample_count, |
| ) |
|
|
| |
| return { |
| "predictions": result_df.to_dict(orient="records"), |
| } |
|
|