File size: 5,745 Bytes
c701170
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
from __future__ import annotations

import math
import os
from statistics import mean
from typing import Any

from schemas import HealthResponse, PredictRequest, PredictResponse, PredictionItem


class ChronosService:
    """CPU-first HF Space service wrapper for Chronos.

    This scaffold is designed for HuggingFace free CPU Spaces and keeps the
    serving contract aligned with `tsf-bridge`. The default backend is a
    deterministic CPU baseline rather than real Chronos inference.

    All limits are read once from ``CHRONOS_*`` environment variables when the
    service is constructed. The ``backend`` value is only reported by
    :meth:`health` and :meth:`describe_runtime`; within this class
    :meth:`predict` always runs the baseline heuristic.
    """

    def __init__(self) -> None:
        """Load runtime configuration from the environment.

        Raises:
            ValueError: If a numeric ``CHRONOS_*`` variable is set to a value
                that ``int()`` / ``float()`` cannot parse.
        """
        self.model_id = "chronos"
        self.model_name = os.getenv(
            "CHRONOS_MODEL_NAME",
            "amazon/chronos-2",
        )
        # A blank or whitespace-only override falls back to the default.
        self.backend = os.getenv("CHRONOS_BACKEND", "baseline_cpu").strip() or "baseline_cpu"
        self.device = "cpu"
        self.ready = True
        self.max_context_length = int(os.getenv("CHRONOS_MAX_CONTEXT_LENGTH", "512"))
        self.max_horizon_step = int(os.getenv("CHRONOS_MAX_HORIZON_STEP", "288"))
        self.confidence_floor = float(os.getenv("CHRONOS_CONFIDENCE_FLOOR", "0.16"))
        self.confidence_ceiling = float(os.getenv("CHRONOS_CONFIDENCE_CEILING", "0.80"))
        self.min_required_points = int(os.getenv("CHRONOS_MIN_REQUIRED_POINTS", "32"))

    def health(self) -> HealthResponse:
        """Return a health report built from the configured attributes.

        ``status`` is always ``"ok"``; readiness is carried by ``ready``.
        """
        return HealthResponse(
            status="ok",
            model=self.model_name,
            model_id=self.model_id,
            backend=self.backend,
            device=self.device,
            ready=self.ready,
            max_context_length=self.max_context_length,
            max_horizon_step=self.max_horizon_step,
        )

    def predict(self, payload: PredictRequest) -> PredictResponse:
        """Validate ``payload`` and forecast one price per requested horizon.

        Only the most recent ``payload.context_length`` close prices are fed
        to the baseline.

        Raises:
            ValueError: If the request violates any configured limit (see
                :meth:`_validate_request`).
        """
        self._validate_request(payload)
        # Validation guarantees 1 <= context_length <= len(close_prices), so
        # this negative slice always selects the newest points.
        closes = payload.close_prices[-payload.context_length :]
        predictions = self._predict_with_baseline(closes, payload.horizons)
        return PredictResponse(model_id=self.model_id, predictions=predictions)

    def _validate_request(self, payload: PredictRequest) -> None:
        """Reject requests that the baseline cannot serve.

        Raises:
            ValueError: If ``context_length`` is below 1, above the configured
                maximum, or larger than the supplied history; if fewer than
                ``min_required_points`` close prices are supplied; or if any
                horizon is negative or above ``max_horizon_step``.
        """
        context_length = payload.context_length
        history_size = len(payload.close_prices)

        # A zero length would make `prices[-0:]` select the WHOLE history
        # (bypassing the max-context cap) and a negative one would slice from
        # the front, so both must be refused before slicing.
        if context_length < 1:
            raise ValueError("context_length must be a positive integer")
        if context_length > self.max_context_length:
            raise ValueError(
                f"context_length {context_length} exceeds "
                f"CHRONOS_MAX_CONTEXT_LENGTH={self.max_context_length}"
            )
        if context_length > history_size:
            raise ValueError("context_length must not exceed len(close_prices)")
        if history_size < self.min_required_points:
            raise ValueError(
                f"at least {self.min_required_points} close prices are required "
                "for CPU baseline stability"
            )
        # Negative steps would otherwise surface later as an opaque
        # "math domain error" from math.log(step + 1.0).
        if any(step < 0 for step in payload.horizons):
            raise ValueError("horizons must not contain negative steps")
        if any(step > self.max_horizon_step for step in payload.horizons):
            raise ValueError(
                f"horizons contain values above CHRONOS_MAX_HORIZON_STEP={self.max_horizon_step}"
            )

    def _predict_with_baseline(
        self, close_prices: list[float], horizons: list[int]
    ) -> list[PredictionItem]:
        """Forecast prices with a deterministic moving-average heuristic.

        The expected return blends three signals taken from the tail of
        ``close_prices``: momentum (last price vs. the 10-point mean), a
        24-vs-64-point mean spread, and the relative change across the
        24-point window. The blend is then scaled by a log-shaped horizon
        factor.

        Args:
            close_prices: Non-empty context window, oldest first.
            horizons: Forecast steps; each must be >= 0.

        Returns:
            One ``PredictionItem`` per entry of ``horizons``, in input order.
        """
        last_price = close_prices[-1]
        # A negative slice bound larger than the list simply yields the whole
        # list, so no explicit clamping against len() is needed.
        short_mean = mean(close_prices[-10:])
        mid_window = close_prices[-24:]
        mid_mean = mean(mid_window)
        long_mean = mean(close_prices[-64:])

        momentum = 0.0 if short_mean == 0 else (last_price - short_mean) / short_mean
        mean_reversion = 0.0 if long_mean == 0 else (mid_mean - long_mean) / long_mean
        local_trend = self._slope(mid_window)

        # Horizon-independent part of the forecast, computed once.
        base_return = momentum * 0.35 + mean_reversion * 0.30 + local_trend * 0.35

        predictions: list[PredictionItem] = []
        for step in horizons:
            # Grows with log(step + 1) and saturates at 1.0 from step 44 on.
            horizon_scale = min(1.0, math.log(step + 1.0) / 3.8)
            expected_return = base_return * horizon_scale

            # Floor matches the 8-decimal rounding below so a price never
            # rounds down to zero or goes negative.
            pred_price = max(0.00000001, last_price * (1.0 + expected_return))
            confidence = self._confidence(close_prices, step, abs(expected_return))
            predictions.append(
                PredictionItem(
                    step=step,
                    pred_price=round(pred_price, 8),
                    pred_confidence=round(confidence, 4),
                )
            )
        return predictions

    def _confidence(
        self, close_prices: list[float], step: int, expected_move_abs: float
    ) -> float:
        """Score a forecast, clamped to the configured floor and ceiling.

        The raw score starts at 0.20 and adds: up to 0.36 for the expected
        move relative to realized volatility, up to 0.20 for low volatility,
        and up to 0.22 that decays with the horizon.

        Args:
            close_prices: Context window the forecast was derived from.
            step: Forecast horizon; must be >= 0.
            expected_move_abs: Absolute expected return of the forecast.

        Returns:
            ``confidence_floor`` when fewer than 3 prices are supplied,
            otherwise the clamped score.
        """
        if len(close_prices) < 3:
            return self.confidence_floor

        changes: list[float] = []
        for previous, current in zip(close_prices[:-1], close_prices[1:]):
            # Relative change is undefined for a non-positive base price.
            if previous <= 0:
                continue
            changes.append(abs((current - previous) / previous))

        # Mean absolute relative change over (at most) the latest 64 moves.
        realized_vol = mean(changes[-64:]) if changes else 0.0
        stability = max(0.0, 1.0 - min(realized_vol * 18.0, 1.0))
        horizon_decay = 1.0 / (1.0 + math.log(step + 1.0))
        # The 1e-9 keeps the ratio finite when the series is perfectly flat.
        raw = 0.20 + min(expected_move_abs / (realized_vol + 1e-9), 2.0) * 0.18
        raw += stability * 0.20 + horizon_decay * 0.22
        return max(self.confidence_floor, min(self.confidence_ceiling, raw))

    @staticmethod
    def _slope(values: list[float]) -> float:
        """Return the relative change from the first to the last value.

        Returns 0.0 when fewer than two values are given or the first value
        is zero (the ratio would be undefined).
        """
        if len(values) < 2 or values[0] == 0:
            return 0.0
        return (values[-1] - values[0]) / values[0]

    def describe_runtime(self) -> dict[str, Any]:
        """Return the effective runtime configuration as a plain dict."""
        return {
            "model_id": self.model_id,
            "model_name": self.model_name,
            "backend": self.backend,
            "device": self.device,
            "ready": self.ready,
            "max_context_length": self.max_context_length,
            "max_horizon_step": self.max_horizon_step,
            "min_required_points": self.min_required_points,
        }