yingfeng64 Claude Sonnet 4.6 commited on
Commit
72a9562
·
0 Parent(s):

Initial deployment: Kronos stock predictor REST API

Browse files

Monte-Carlo probabilistic forecasting using NeoQuasar/Kronos-base.
Async task queue (POST submit / GET poll), Tushare qfq data source.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>

Files changed (7) hide show
  1. .gitignore +5 -0
  2. Dockerfile +27 -0
  3. README.md +107 -0
  4. app.py +211 -0
  5. data_fetcher.py +77 -0
  6. predictor.py +121 -0
  7. requirements.txt +12 -0
.gitignore ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ __pycache__/
2
+ *.pyc
3
+ *.pyo
4
+ .env
5
+ docs/
Dockerfile ADDED
@@ -0,0 +1,27 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # ── Build stage ───────────────────────────────────────────────────────────────
2
+ FROM python:3.11-slim
3
+
4
+ # System deps
5
+ RUN apt-get update && apt-get install -y --no-install-recommends \
6
+ git \
7
+ && rm -rf /var/lib/apt/lists/*
8
+
9
+ WORKDIR /app
10
+
11
+ # Clone Kronos source at build time (avoids runtime clone delay)
12
+ RUN git clone --depth 1 https://github.com/shiyu-coder/Kronos /app/Kronos
13
+
14
+ # Install Python deps first (layer cache friendly)
15
+ COPY requirements.txt .
16
+ RUN pip install --no-cache-dir -r requirements.txt
17
+
18
+ # Copy application source
19
+ COPY app.py predictor.py data_fetcher.py ./
20
+
21
+ # HuggingFace Spaces default port
22
+ EXPOSE 7860
23
+
24
+ # KRONOS_DIR tells predictor.py where the source lives (already cloned above)
25
+ ENV KRONOS_DIR=/app/Kronos
26
+
27
+ CMD ["uvicorn", "app:app", "--host", "0.0.0.0", "--port", "7860"]
README.md ADDED
@@ -0,0 +1,107 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ title: Kronos Stock Predictor API
3
+ emoji: 📈
4
+ colorFrom: blue
5
+ colorTo: indigo
6
+ sdk: docker
7
+ pinned: false
8
+ ---
9
+
10
+ # Kronos Stock Predictor API
11
+
12
+ Monte-Carlo probabilistic stock forecasting powered by
13
+ [Kronos](https://arxiv.org/abs/2508.02739) — Tsinghua University's open-source
14
+ financial K-line foundation model.
15
+
16
+ Data source: **Tushare Pro** (front-adjusted / qfq).
17
+
18
+ ---
19
+
20
+ ## Endpoints
21
+
22
+ ### `POST /api/v1/predict`
23
+
24
+ Submit a prediction job. Returns a `task_id` immediately.
25
+
26
+ **Request body**
27
+
28
+ | Field | Type | Default | Description |
29
+ |-------|------|---------|-------------|
30
+ | `ts_code` | string | — | Tushare stock code, e.g. `"600900.SH"` |
31
+ | `lookback` | int | 512 | Historical bars to feed the model (1–512) |
32
+ | `pred_len` | int | 5 | Future trading days to predict (1–60) |
33
+ | `sample_count` | int | 30 | MC sampling iterations (1–100) |
34
+ | `mode` | string | `"simple"` | `"simple"` or `"advanced"` |
35
+ | `include_volume` | bool | false | Include volume CI in `advanced` mode |
36
+
37
+ ```json
38
+ {
39
+ "ts_code": "600900.SH",
40
+ "lookback": 512,
41
+ "pred_len": 5,
42
+ "sample_count": 30,
43
+ "mode": "simple"
44
+ }
45
+ ```
46
+
47
+ **Response**
48
+
49
+ ```json
50
+ { "task_id": "550e8400-e29b-41d4-a716-446655440000" }
51
+ ```
52
+
53
+ ---
54
+
55
+ ### `GET /api/v1/predict/{task_id}`
56
+
57
+ Poll for results.
58
+
59
+ ```json
60
+ {
61
+ "status": "done",
62
+ "error": null,
63
+ "result": {
64
+ "ts_code": "600900.SH",
65
+ "base_date": "2026-03-13",
66
+ "pred_len": 5,
67
+ "confidence": 95,
68
+ "confidence_warning": false,
69
+ "direction": { "signal": "bullish", "probability": 0.73 },
70
+ "summary": {
71
+ "mean_close": 27.05,
72
+ "range_low": 25.80,
73
+ "range_high": 28.30,
74
+ "range_width": 2.50
75
+ },
76
+ "bands": [
77
+ {
78
+ "date": "2026-03-14",
79
+ "step": 1,
80
+ "mean_close": 26.88,
81
+ "trading_low": 26.20,
82
+ "trading_high": 27.55,
83
+ "uncertainty": 0.0504
84
+ }
85
+ ]
86
+ }
87
+ }
88
+ ```
89
+
90
+ `status` is one of `"pending"` / `"done"` / `"failed"`.
91
+
92
+ ---
93
+
94
+ ### `GET /health`
95
+
96
+ ```json
97
+ { "status": "ok" }
98
+ ```
99
+
100
+ ---
101
+
102
+ ## Notes
103
+
104
+ - **Direction signal**: based on the last predicted close vs. the last historical close across all MC samples.
105
+ - **95 % trading band**: `trading_low` = q2.5 of daily predicted lows; `trading_high` = q97.5 of daily predicted highs.
106
+ - **`confidence_warning: true`** when `pred_len > 30` (model uncertainty grows significantly beyond ~30 days).
107
+ - CPU inference: ~3–5 s/sample → 30 samples ≈ 2–5 min. Consider selecting GPU hardware for production use.
app.py ADDED
@@ -0,0 +1,211 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Kronos Stock Predictor — RESTful API
3
+ =====================================
4
+ POST /api/v1/predict → { "task_id": "uuid" }
5
+ GET /api/v1/predict/{id} → { "status": "pending|done|failed", "result": {...} }
6
+ GET /health → { "status": "ok" }
7
+ """
8
+
9
+ import asyncio
10
+ import logging
11
+ import uuid
12
+ from concurrent.futures import ThreadPoolExecutor
13
+ from contextlib import asynccontextmanager
14
+ from typing import Literal
15
+
16
+ import pandas as pd
17
+ from fastapi import FastAPI, HTTPException
18
+ from fastapi.middleware.cors import CORSMiddleware
19
+ from pydantic import BaseModel, Field
20
+
21
+ import data_fetcher
22
+ import predictor as pred_module
23
+
24
+ logging.basicConfig(level=logging.INFO)
25
+ logger = logging.getLogger(__name__)
26
+
27
+ # ── Task store (in-process; sufficient for single-worker deployments) ─────────
28
+ _tasks: dict[str, dict] = {}
29
+ _executor = ThreadPoolExecutor(max_workers=2)
30
+
31
+
32
+ # ── Startup: eagerly load the model so the first request isn't slow ───────────
33
+ @asynccontextmanager
34
+ async def lifespan(app: FastAPI):
35
+ loop = asyncio.get_event_loop()
36
+ logger.info("Pre-loading Kronos predictor …")
37
+ await loop.run_in_executor(_executor, pred_module.get_predictor)
38
+ logger.info("Kronos predictor ready.")
39
+ yield
40
+
41
+
42
+ app = FastAPI(
43
+ title="Kronos Stock Predictor API",
44
+ version="1.0.0",
45
+ description=(
46
+ "Monte-Carlo probabilistic stock forecasting powered by the "
47
+ "Kronos foundation model (Tsinghua University)."
48
+ ),
49
+ lifespan=lifespan,
50
+ )
51
+ app.add_middleware(
52
+ CORSMiddleware,
53
+ allow_origins=["*"],
54
+ allow_methods=["*"],
55
+ allow_headers=["*"],
56
+ )
57
+
58
+
59
+ # ── Request / Response schemas ────────────────────────────────────────────────
60
+ class PredictRequest(BaseModel):
61
+ ts_code: str = Field(..., examples=["600900.SH"], description="Tushare 股票代码")
62
+ lookback: int = Field(
63
+ default=512,
64
+ ge=20,
65
+ le=512,
66
+ description="回看历史 K 线根数(最多 512,不足时自动截断)",
67
+ )
68
+ pred_len: int = Field(
69
+ default=5,
70
+ ge=1,
71
+ le=60,
72
+ description="预测未来交易日数(建议 ≤ 30,超过时返回 confidence_warning)",
73
+ )
74
+ sample_count: int = Field(
75
+ default=30,
76
+ ge=1,
77
+ le=100,
78
+ description="MC 蒙特卡洛采样次数",
79
+ )
80
+ mode: Literal["simple", "advanced"] = Field(
81
+ default="simple",
82
+ description="simple: 仅返回均值 + 交易区间;advanced: 追加 OHLC 均值及收盘 CI",
83
+ )
84
+ include_volume: bool = Field(
85
+ default=False,
86
+ description="mode=advanced 时是否额外返回成交量预测(默认关闭)",
87
+ )
88
+
89
+
90
+ # ── Response builder ──────────────────────────────────────────────────────────
91
+ def _build_response(req: PredictRequest, base_date: str, pred_mean, ci,
92
+ trading_low, trading_high, direction_prob, last_close,
93
+ y_timestamp) -> dict:
94
+ bands = []
95
+ for i in range(req.pred_len):
96
+ band: dict = {
97
+ "date": str(y_timestamp.iloc[i].date()),
98
+ "step": i + 1,
99
+ "mean_close": round(float(pred_mean["close"].iloc[i]), 4),
100
+ "trading_low": round(float(trading_low[i]), 4),
101
+ "trading_high": round(float(trading_high[i]), 4),
102
+ "uncertainty": round(
103
+ float((trading_high[i] - trading_low[i]) / last_close), 4
104
+ ),
105
+ }
106
+ if req.mode == "advanced":
107
+ band.update({
108
+ "mean_open": round(float(pred_mean["open"].iloc[i]), 4),
109
+ "mean_high": round(float(pred_mean["high"].iloc[i]), 4),
110
+ "mean_low": round(float(pred_mean["low"].iloc[i]), 4),
111
+ "close_ci_low": round(float(ci["close"]["low"][i]), 4),
112
+ "close_ci_high": round(float(ci["close"]["high"][i]), 4),
113
+ })
114
+ bands.append(band)
115
+
116
+ result: dict = {
117
+ "ts_code": req.ts_code,
118
+ "base_date": base_date,
119
+ "pred_len": req.pred_len,
120
+ "confidence": 95,
121
+ "confidence_warning": req.pred_len > 30,
122
+ "direction": {
123
+ "signal": "bullish" if direction_prob >= 0.5 else "bearish",
124
+ "probability": round(direction_prob, 4),
125
+ },
126
+ "summary": {
127
+ "mean_close": round(float(pred_mean["close"].iloc[-1]), 4),
128
+ "range_low": round(float(trading_low.min()), 4),
129
+ "range_high": round(float(trading_high.max()), 4),
130
+ "range_width": round(float(trading_high.max() - trading_low.min()), 4),
131
+ },
132
+ "bands": bands,
133
+ }
134
+
135
+ if req.mode == "advanced" and req.include_volume:
136
+ result["volume"] = [
137
+ {
138
+ "date": str(y_timestamp.iloc[i].date()),
139
+ "mean_volume": round(float(pred_mean["volume"].iloc[i])),
140
+ "volume_ci_low": round(float(ci["volume"]["low"][i])),
141
+ "volume_ci_high": round(float(ci["volume"]["high"][i])),
142
+ }
143
+ for i in range(req.pred_len)
144
+ ]
145
+
146
+ return result
147
+
148
+
149
+ # ── Background task ───────────────────────────────────────────────────────────
150
+ def _run_prediction(task_id: str, req: PredictRequest) -> None:
151
+ try:
152
+ x_df, x_timestamp, last_trade_date = data_fetcher.fetch_stock_data(
153
+ req.ts_code, req.lookback
154
+ )
155
+ y_timestamp = data_fetcher.get_future_trading_dates(last_trade_date, req.pred_len)
156
+
157
+ pred_mean, ci, trading_low, trading_high, direction_prob, last_close = (
158
+ pred_module.run_mc_prediction(
159
+ x_df, x_timestamp, y_timestamp, req.pred_len, req.sample_count
160
+ )
161
+ )
162
+
163
+ base_date = str(pd.to_datetime(last_trade_date, format="%Y%m%d").date())
164
+ result = _build_response(
165
+ req, base_date, pred_mean, ci,
166
+ trading_low, trading_high, direction_prob, last_close, y_timestamp,
167
+ )
168
+ _tasks[task_id] = {"status": "done", "result": result, "error": None}
169
+ except Exception as exc:
170
+ logger.exception("Task %s failed", task_id)
171
+ _tasks[task_id] = {"status": "failed", "result": None, "error": str(exc)}
172
+
173
+
174
+ # ── Routes ────────────────────────────────────────────────────────────────────
175
+ @app.post(
176
+ "/api/v1/predict",
177
+ summary="提交预测任务",
178
+ response_description="任务 ID,用于轮询结果",
179
+ )
180
+ async def submit_predict(req: PredictRequest):
181
+ """
182
+ 提交一个蒙特卡洛预测任务,立即返回 `task_id`。
183
+ 通过 `GET /api/v1/predict/{task_id}` 轮询结果。
184
+ """
185
+ task_id = str(uuid.uuid4())
186
+ _tasks[task_id] = {"status": "pending", "result": None, "error": None}
187
+ _executor.submit(_run_prediction, task_id, req)
188
+ return {"task_id": task_id}
189
+
190
+
191
+ @app.get(
192
+ "/api/v1/predict/{task_id}",
193
+ summary="查询预测结果",
194
+ )
195
+ async def get_predict_result(task_id: str):
196
+ """
197
+ 轮询预测任务状态。
198
+
199
+ - `status: "pending"` — 正在计算
200
+ - `status: "done"` — 完成,`result` 字段包含预测数据
201
+ - `status: "failed"` — 失败,`error` 字段包含错误信息
202
+ """
203
+ task = _tasks.get(task_id)
204
+ if task is None:
205
+ raise HTTPException(status_code=404, detail=f"Task {task_id!r} not found")
206
+ return task
207
+
208
+
209
+ @app.get("/health", summary="健康检查")
210
+ async def health():
211
+ return {"status": "ok"}
data_fetcher.py ADDED
@@ -0,0 +1,77 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from datetime import datetime, timedelta
3
+
4
+ import pandas as pd
5
+ import tushare as ts
6
+
7
+ TUSHARE_TOKEN = os.environ.get(
8
+ "TUSHARE_TOKEN",
9
+ )
10
+
11
+ ts.set_token(TUSHARE_TOKEN)
12
+ _pro = ts.pro_api()
13
+
14
+
15
+ def fetch_stock_data(
16
+ ts_code: str, lookback: int
17
+ ) -> tuple[pd.DataFrame, pd.Series, str]:
18
+ """
19
+ Returns:
20
+ x_df : DataFrame with columns [open, high, low, close, volume, amount]
21
+ x_timestamp : pd.Series[datetime], aligned to x_df
22
+ last_trade_date: str "YYYYMMDD", the most recent bar date
23
+ """
24
+ end_date = datetime.today().strftime("%Y%m%d")
25
+ # 2× buffer to account for weekends/holidays
26
+ start_date = (datetime.today() - timedelta(days=lookback * 2)).strftime("%Y%m%d")
27
+
28
+ df = ts.pro_bar(
29
+ ts_code=ts_code,
30
+ adj="qfq",
31
+ start_date=start_date,
32
+ end_date=end_date,
33
+ asset="E",
34
+ )
35
+
36
+ if df is None or df.empty:
37
+ raise ValueError(f"No data returned for ts_code={ts_code!r}")
38
+
39
+ df = df.sort_values("trade_date").reset_index(drop=True)
40
+ df = df.rename(columns={"vol": "volume"})
41
+ df["timestamps"] = pd.to_datetime(df["trade_date"], format="%Y%m%d")
42
+
43
+ # Keep the most recent `lookback` bars
44
+ df = df.tail(lookback).reset_index(drop=True)
45
+
46
+ x_df = df[["open", "high", "low", "close", "volume", "amount"]].copy()
47
+ x_timestamp = df["timestamps"].copy()
48
+ last_trade_date = df["trade_date"].iloc[-1]
49
+
50
+ return x_df, x_timestamp, last_trade_date
51
+
52
+
53
+ def get_future_trading_dates(last_trade_date: str, pred_len: int) -> pd.Series:
54
+ """
55
+ Return a pd.Series of `pred_len` future SSE trading dates (datetime) that
56
+ follow `last_trade_date` (format: YYYYMMDD).
57
+ """
58
+ last_dt = datetime.strptime(last_trade_date, "%Y%m%d")
59
+ # 3× buffer so we always have enough dates even over a long holiday
60
+ end_dt = last_dt + timedelta(days=pred_len * 3)
61
+
62
+ cal = _pro.trade_cal(
63
+ exchange="SSE",
64
+ start_date=(last_dt + timedelta(days=1)).strftime("%Y%m%d"),
65
+ end_date=end_dt.strftime("%Y%m%d"),
66
+ is_open="1",
67
+ )
68
+ cal = cal.sort_values("cal_date")
69
+ dates = pd.to_datetime(cal["cal_date"].values[:pred_len], format="%Y%m%d")
70
+
71
+ if len(dates) < pred_len:
72
+ raise ValueError(
73
+ f"Could only obtain {len(dates)} future trading dates; "
74
+ f"increase buffer or check Tushare calendar coverage."
75
+ )
76
+
77
+ return pd.Series(dates)
predictor.py ADDED
@@ -0,0 +1,121 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Kronos model singleton + Monte-Carlo prediction logic.
3
+
4
+ On import this module:
5
+ 1. Clones shiyu-coder/Kronos from GitHub if not already present at KRONOS_DIR.
6
+ 2. Adds KRONOS_DIR to sys.path so `from model import ...` works.
7
+ 3. Does NOT load the model weights yet (lazy, first-request).
8
+ """
9
+
10
+ import logging
11
+ import os
12
+ import subprocess
13
+ import sys
14
+ from typing import Tuple
15
+
16
+ import numpy as np
17
+ import pandas as pd
18
+ import torch
19
+
20
+ logger = logging.getLogger(__name__)
21
+
22
+ # ── Paths / IDs ─────────────────────────────────────────────────────────────
23
+ KRONOS_DIR = os.environ.get("KRONOS_DIR", "/app/Kronos")
24
+ MODEL_ID = "NeoQuasar/Kronos-base"
25
+ TOKENIZER_ID = "NeoQuasar/Kronos-Tokenizer-base"
26
+
27
+
28
+ # ── Bootstrap Kronos source ──────────────────────────────────────────────────
29
+ def _ensure_kronos_source() -> None:
30
+ if not os.path.isdir(KRONOS_DIR):
31
+ logger.info("Cloning Kronos source to %s …", KRONOS_DIR)
32
+ subprocess.run(
33
+ [
34
+ "git", "clone", "--depth", "1",
35
+ "https://github.com/shiyu-coder/Kronos",
36
+ KRONOS_DIR,
37
+ ],
38
+ check=True,
39
+ )
40
+ if KRONOS_DIR not in sys.path:
41
+ sys.path.insert(0, KRONOS_DIR)
42
+
43
+
44
+ _ensure_kronos_source()
45
+
46
+ from model import Kronos, KronosPredictor, KronosTokenizer # noqa: E402 (after sys.path setup)
47
+
48
+ # ── Global singleton ─────────────────────────────────────────────────────────
49
+ _predictor: KronosPredictor | None = None
50
+
51
+
52
+ def get_predictor() -> KronosPredictor:
53
+ global _predictor
54
+ if _predictor is None:
55
+ device = "cuda" if torch.cuda.is_available() else "cpu"
56
+ logger.info("Loading Kronos model on %s …", device)
57
+ tokenizer = KronosTokenizer.from_pretrained(TOKENIZER_ID)
58
+ model = Kronos.from_pretrained(MODEL_ID)
59
+ _predictor = KronosPredictor(model, tokenizer, device=device, max_context=512)
60
+ logger.info("Kronos predictor ready.")
61
+ return _predictor
62
+
63
+
64
+ # ── Monte-Carlo prediction ────────────────────────────────────────────────────
65
+ def run_mc_prediction(
66
+ x_df: pd.DataFrame,
67
+ x_timestamp: pd.Series,
68
+ y_timestamp: pd.Series,
69
+ pred_len: int,
70
+ sample_count: int,
71
+ ) -> Tuple[pd.DataFrame, dict, np.ndarray, np.ndarray, float, float]:
72
+ """
73
+ Run `sample_count` independent samples (each with sample_count=1) to build
74
+ MC statistics.
75
+
76
+ Returns:
77
+ pred_mean : DataFrame (index=y_timestamp, cols=OHLCVA), 均值轨迹
78
+ ci : dict[field]["low"/"high"] → ndarray(pred_len,), 95% CI
79
+ trading_low : ndarray(pred_len,), q2.5 of predicted_low
80
+ trading_high : ndarray(pred_len,), q97.5 of predicted_high
81
+ direction_prob : float ∈ [0,1], fraction of samples where final close > last close
82
+ last_close : float, closing price of the last historical bar
83
+ """
84
+ predictor = get_predictor()
85
+ samples = []
86
+
87
+ for _ in range(sample_count):
88
+ s = predictor.predict(
89
+ df=x_df,
90
+ x_timestamp=x_timestamp,
91
+ y_timestamp=y_timestamp,
92
+ pred_len=pred_len,
93
+ T=0.8,
94
+ top_p=0.9,
95
+ sample_count=1,
96
+ verbose=False,
97
+ )
98
+ samples.append(s)
99
+
100
+ pred_mean = pd.concat(samples).groupby(level=0).mean()
101
+
102
+ def stack(field: str) -> np.ndarray:
103
+ return np.stack([s[field].values for s in samples]) # (sample_count, pred_len)
104
+
105
+ alpha = 2.5 # → 95 % CI
106
+ ci = {
107
+ field: {
108
+ "low": np.percentile(stack(field), alpha, axis=0),
109
+ "high": np.percentile(stack(field), 100 - alpha, axis=0),
110
+ }
111
+ for field in ["open", "high", "low", "close", "volume"]
112
+ }
113
+
114
+ trading_low = ci["low"]["low"] # q2.5 of the predicted daily low
115
+ trading_high = ci["high"]["high"] # q97.5 of the predicted daily high
116
+
117
+ last_close = float(x_df["close"].iloc[-1])
118
+ bull_count = sum(float(s["close"].iloc[-1]) > last_close for s in samples)
119
+ direction_prob = bull_count / sample_count
120
+
121
+ return pred_mean, ci, trading_low, trading_high, direction_prob, last_close
requirements.txt ADDED
@@ -0,0 +1,12 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ fastapi>=0.111.0
2
+ uvicorn[standard]>=0.29.0
3
+ pydantic>=2.0.0
4
+ numpy
5
+ pandas==2.2.2
6
+ torch>=2.0.0
7
+ einops==0.8.1
8
+ huggingface_hub==0.33.1
9
+ matplotlib==3.9.3
10
+ tqdm==4.67.1
11
+ safetensors==0.6.2
12
+ tushare