Spaces:
Running
Running
yingfeng64 Claude Sonnet 4.6 commited on
Commit ·
72a9562
0
Parent(s):
Initial deployment: Kronos stock predictor REST API
Browse filesMonte-Carlo probabilistic forecasting using NeoQuasar/Kronos-base.
Async task queue (POST submit / GET poll), Tushare qfq data source.
Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
- .gitignore +5 -0
- Dockerfile +27 -0
- README.md +107 -0
- app.py +211 -0
- data_fetcher.py +77 -0
- predictor.py +121 -0
- requirements.txt +12 -0
.gitignore
ADDED
|
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
__pycache__/
|
| 2 |
+
*.pyc
|
| 3 |
+
*.pyo
|
| 4 |
+
.env
|
| 5 |
+
docs/
|
Dockerfile
ADDED
|
@@ -0,0 +1,27 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# ── Runtime image (single stage) ──────────────────────────────────────────────
|
| 2 |
+
FROM python:3.11-slim
|
| 3 |
+
|
| 4 |
+
# System deps
|
| 5 |
+
RUN apt-get update && apt-get install -y --no-install-recommends \
|
| 6 |
+
git \
|
| 7 |
+
&& rm -rf /var/lib/apt/lists/*
|
| 8 |
+
|
| 9 |
+
WORKDIR /app
|
| 10 |
+
|
| 11 |
+
# Clone Kronos source at build time (avoids runtime clone delay)
|
| 12 |
+
RUN git clone --depth 1 https://github.com/shiyu-coder/Kronos /app/Kronos
|
| 13 |
+
|
| 14 |
+
# Install Python deps first (layer cache friendly)
|
| 15 |
+
COPY requirements.txt .
|
| 16 |
+
RUN pip install --no-cache-dir -r requirements.txt
|
| 17 |
+
|
| 18 |
+
# Copy application source
|
| 19 |
+
COPY app.py predictor.py data_fetcher.py ./
|
| 20 |
+
|
| 21 |
+
# HuggingFace Spaces default port
|
| 22 |
+
EXPOSE 7860
|
| 23 |
+
|
| 24 |
+
# KRONOS_DIR tells predictor.py where the source lives (already cloned above)
|
| 25 |
+
ENV KRONOS_DIR=/app/Kronos
|
| 26 |
+
|
| 27 |
+
CMD ["uvicorn", "app:app", "--host", "0.0.0.0", "--port", "7860"]
|
README.md
ADDED
|
@@ -0,0 +1,107 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
---
|
| 2 |
+
title: Kronos Stock Predictor API
|
| 3 |
+
emoji: 📈
|
| 4 |
+
colorFrom: blue
|
| 5 |
+
colorTo: indigo
|
| 6 |
+
sdk: docker
|
| 7 |
+
pinned: false
|
| 8 |
+
---
|
| 9 |
+
|
| 10 |
+
# Kronos Stock Predictor API
|
| 11 |
+
|
| 12 |
+
Monte-Carlo probabilistic stock forecasting powered by
|
| 13 |
+
[Kronos](https://arxiv.org/abs/2508.02739) — Tsinghua University's open-source
|
| 14 |
+
financial K-line foundation model.
|
| 15 |
+
|
| 16 |
+
Data source: **Tushare Pro** (front-adjusted / qfq).
|
| 17 |
+
|
| 18 |
+
---
|
| 19 |
+
|
| 20 |
+
## Endpoints
|
| 21 |
+
|
| 22 |
+
### `POST /api/v1/predict`
|
| 23 |
+
|
| 24 |
+
Submit a prediction job. Returns a `task_id` immediately.
|
| 25 |
+
|
| 26 |
+
**Request body**
|
| 27 |
+
|
| 28 |
+
| Field | Type | Default | Description |
|
| 29 |
+
|-------|------|---------|-------------|
|
| 30 |
+
| `ts_code` | string | — | Tushare stock code, e.g. `"600900.SH"` |
|
| 31 |
+
| `lookback` | int | 512 | Historical bars to feed the model (1–512) |
|
| 32 |
+
| `pred_len` | int | 5 | Future trading days to predict (1–60) |
|
| 33 |
+
| `sample_count` | int | 30 | MC sampling iterations (1–100) |
|
| 34 |
+
| `mode` | string | `"simple"` | `"simple"` or `"advanced"` |
|
| 35 |
+
| `include_volume` | bool | false | Include volume CI in `advanced` mode |
|
| 36 |
+
|
| 37 |
+
```json
|
| 38 |
+
{
|
| 39 |
+
"ts_code": "600900.SH",
|
| 40 |
+
"lookback": 512,
|
| 41 |
+
"pred_len": 5,
|
| 42 |
+
"sample_count": 30,
|
| 43 |
+
"mode": "simple"
|
| 44 |
+
}
|
| 45 |
+
```
|
| 46 |
+
|
| 47 |
+
**Response**
|
| 48 |
+
|
| 49 |
+
```json
|
| 50 |
+
{ "task_id": "550e8400-e29b-41d4-a716-446655440000" }
|
| 51 |
+
```
|
| 52 |
+
|
| 53 |
+
---
|
| 54 |
+
|
| 55 |
+
### `GET /api/v1/predict/{task_id}`
|
| 56 |
+
|
| 57 |
+
Poll for results.
|
| 58 |
+
|
| 59 |
+
```json
|
| 60 |
+
{
|
| 61 |
+
"status": "done",
|
| 62 |
+
"error": null,
|
| 63 |
+
"result": {
|
| 64 |
+
"ts_code": "600900.SH",
|
| 65 |
+
"base_date": "2026-03-13",
|
| 66 |
+
"pred_len": 5,
|
| 67 |
+
"confidence": 95,
|
| 68 |
+
"confidence_warning": false,
|
| 69 |
+
"direction": { "signal": "bullish", "probability": 0.73 },
|
| 70 |
+
"summary": {
|
| 71 |
+
"mean_close": 27.05,
|
| 72 |
+
"range_low": 25.80,
|
| 73 |
+
"range_high": 28.30,
|
| 74 |
+
"range_width": 2.50
|
| 75 |
+
},
|
| 76 |
+
"bands": [
|
| 77 |
+
{
|
| 78 |
+
"date": "2026-03-14",
|
| 79 |
+
"step": 1,
|
| 80 |
+
"mean_close": 26.88,
|
| 81 |
+
"trading_low": 26.20,
|
| 82 |
+
"trading_high": 27.55,
|
| 83 |
+
"uncertainty": 0.0504
|
| 84 |
+
}
|
| 85 |
+
]
|
| 86 |
+
}
|
| 87 |
+
}
|
| 88 |
+
```
|
| 89 |
+
|
| 90 |
+
`status` is one of `"pending"` / `"done"` / `"failed"`.
|
| 91 |
+
|
| 92 |
+
---
|
| 93 |
+
|
| 94 |
+
### `GET /health`
|
| 95 |
+
|
| 96 |
+
```json
|
| 97 |
+
{ "status": "ok" }
|
| 98 |
+
```
|
| 99 |
+
|
| 100 |
+
---
|
| 101 |
+
|
| 102 |
+
## Notes
|
| 103 |
+
|
| 104 |
+
- **Direction signal**: based on the last predicted close vs. the last historical close across all MC samples.
|
| 105 |
+
- **95 % trading band**: `trading_low` = q2.5 of daily predicted lows; `trading_high` = q97.5 of daily predicted highs.
|
| 106 |
+
- **`confidence_warning: true`** when `pred_len > 30` (model uncertainty grows significantly beyond ~30 days).
|
| 107 |
+
- CPU inference: ~3–5 s/sample → 30 samples ≈ 2–5 min. Consider selecting GPU hardware for production use.
|
app.py
ADDED
|
@@ -0,0 +1,211 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Kronos Stock Predictor — RESTful API
|
| 3 |
+
=====================================
|
| 4 |
+
POST /api/v1/predict → { "task_id": "uuid" }
|
| 5 |
+
GET /api/v1/predict/{id} → { "status": "pending|done|failed", "result": {...} }
|
| 6 |
+
GET /health → { "status": "ok" }
|
| 7 |
+
"""
|
| 8 |
+
|
| 9 |
+
import asyncio
|
| 10 |
+
import logging
|
| 11 |
+
import uuid
|
| 12 |
+
from concurrent.futures import ThreadPoolExecutor
|
| 13 |
+
from contextlib import asynccontextmanager
|
| 14 |
+
from typing import Literal
|
| 15 |
+
|
| 16 |
+
import pandas as pd
|
| 17 |
+
from fastapi import FastAPI, HTTPException
|
| 18 |
+
from fastapi.middleware.cors import CORSMiddleware
|
| 19 |
+
from pydantic import BaseModel, Field
|
| 20 |
+
|
| 21 |
+
import data_fetcher
|
| 22 |
+
import predictor as pred_module
|
| 23 |
+
|
| 24 |
+
logging.basicConfig(level=logging.INFO)
|
| 25 |
+
logger = logging.getLogger(__name__)
|
| 26 |
+
|
| 27 |
+
# ── Task store (in-process; sufficient for single-worker deployments) ─────────
# task_id → {"status": "pending"|"done"|"failed", "result": dict|None, "error": str|None}
_tasks: dict[str, dict] = {}
# Two workers: one slot covers the startup model pre-load, leaving one free
# for prediction jobs submitted while it runs.
_executor = ThreadPoolExecutor(max_workers=2)
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
# ── Startup: eagerly load the model so the first request isn't slow ───────────
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Pre-load the Kronos predictor in a worker thread before serving.

    Runs once at application startup; the model load is blocking, so it is
    pushed onto the shared executor to keep the event loop responsive.
    """
    # get_running_loop() is the supported call inside a coroutine;
    # get_event_loop() here is deprecated since Python 3.10.
    loop = asyncio.get_running_loop()
    logger.info("Pre-loading Kronos predictor …")
    await loop.run_in_executor(_executor, pred_module.get_predictor)
    logger.info("Kronos predictor ready.")
    yield
|
| 40 |
+
|
| 41 |
+
|
| 42 |
+
# FastAPI application; the lifespan hook above pre-loads the model at startup.
app = FastAPI(
    title="Kronos Stock Predictor API",
    version="1.0.0",
    description=(
        "Monte-Carlo probabilistic stock forecasting powered by the "
        "Kronos foundation model (Tsinghua University)."
    ),
    lifespan=lifespan,
)
# NOTE(review): wildcard CORS is wide open — acceptable for a public demo
# Space, but tighten `allow_origins` before exposing anything sensitive.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_methods=["*"],
    allow_headers=["*"],
)
|
| 57 |
+
|
| 58 |
+
|
| 59 |
+
# ── Request / Response schemas ────────────────────────────────────────────────
class PredictRequest(BaseModel):
    """Body of `POST /api/v1/predict`.

    Validation bounds mirror model limits: the predictor is built with
    max_context=512 bars (see predictor.py), and `confidence_warning` is set
    in responses when pred_len exceeds 30.
    """

    ts_code: str = Field(..., examples=["600900.SH"], description="Tushare 股票代码")
    lookback: int = Field(
        default=512,
        ge=20,
        le=512,
        description="回看历史 K 线根数(最多 512,不足时自动截断)",
    )
    pred_len: int = Field(
        default=5,
        ge=1,
        le=60,
        description="预测未来交易日数(建议 ≤ 30,超过时返回 confidence_warning)",
    )
    sample_count: int = Field(
        default=30,
        ge=1,
        le=100,
        description="MC 蒙特卡洛采样次数",
    )
    mode: Literal["simple", "advanced"] = Field(
        default="simple",
        description="simple: 仅返回均值 + 交易区间;advanced: 追加 OHLC 均值及收盘 CI",
    )
    include_volume: bool = Field(
        default=False,
        description="mode=advanced 时是否额外返回成交量预测(默认关闭)",
    )
|
| 88 |
+
|
| 89 |
+
|
| 90 |
+
# ── Response builder ──────────────────────────────────────────────────────────
def _build_response(req: PredictRequest, base_date: str, pred_mean, ci,
                    trading_low, trading_high, direction_prob, last_close,
                    y_timestamp) -> dict:
    """Assemble the JSON-serializable prediction payload for one task."""

    def _one_band(step_idx: int) -> dict:
        # Per-day band: mean close, the 95 % trading range, and the range's
        # width relative to the last historical close ("uncertainty").
        lo = float(trading_low[step_idx])
        hi = float(trading_high[step_idx])
        entry: dict = {
            "date": str(y_timestamp.iloc[step_idx].date()),
            "step": step_idx + 1,
            "mean_close": round(float(pred_mean["close"].iloc[step_idx]), 4),
            "trading_low": round(lo, 4),
            "trading_high": round(hi, 4),
            "uncertainty": round((hi - lo) / last_close, 4),
        }
        if req.mode == "advanced":
            # Advanced mode adds the mean OHL values and the close-price CI.
            entry["mean_open"] = round(float(pred_mean["open"].iloc[step_idx]), 4)
            entry["mean_high"] = round(float(pred_mean["high"].iloc[step_idx]), 4)
            entry["mean_low"] = round(float(pred_mean["low"].iloc[step_idx]), 4)
            entry["close_ci_low"] = round(float(ci["close"]["low"][step_idx]), 4)
            entry["close_ci_high"] = round(float(ci["close"]["high"][step_idx]), 4)
        return entry

    bands = [_one_band(i) for i in range(req.pred_len)]

    # Whole-horizon extremes for the summary section.
    lo_all = float(trading_low.min())
    hi_all = float(trading_high.max())

    result: dict = {
        "ts_code": req.ts_code,
        "base_date": base_date,
        "pred_len": req.pred_len,
        "confidence": 95,
        "confidence_warning": req.pred_len > 30,
        "direction": {
            "signal": "bullish" if direction_prob >= 0.5 else "bearish",
            "probability": round(direction_prob, 4),
        },
        "summary": {
            "mean_close": round(float(pred_mean["close"].iloc[-1]), 4),
            "range_low": round(lo_all, 4),
            "range_high": round(hi_all, 4),
            "range_width": round(hi_all - lo_all, 4),
        },
        "bands": bands,
    }

    if req.mode == "advanced" and req.include_volume:
        # Volume predictions are rounded to whole shares/lots (no decimals).
        result["volume"] = [
            {
                "date": str(y_timestamp.iloc[i].date()),
                "mean_volume": round(float(pred_mean["volume"].iloc[i])),
                "volume_ci_low": round(float(ci["volume"]["low"][i])),
                "volume_ci_high": round(float(ci["volume"]["high"][i])),
            }
            for i in range(req.pred_len)
        ]

    return result
|
| 147 |
+
|
| 148 |
+
|
| 149 |
+
# ── Background task ───────────────────────────────────────────────────────────
def _run_prediction(task_id: str, req: PredictRequest) -> None:
    """Worker-thread entry point: fetch data, run MC sampling, store outcome.

    Always leaves a terminal state in `_tasks[task_id]`: either
    {"status": "done", "result": ...} or {"status": "failed", "error": ...}.
    """
    try:
        x_df, x_timestamp, last_trade_date = data_fetcher.fetch_stock_data(
            req.ts_code, req.lookback
        )
        y_timestamp = data_fetcher.get_future_trading_dates(
            last_trade_date, req.pred_len
        )

        (pred_mean, ci, trading_low, trading_high,
         direction_prob, last_close) = pred_module.run_mc_prediction(
            x_df, x_timestamp, y_timestamp, req.pred_len, req.sample_count
        )

        base_date = str(pd.to_datetime(last_trade_date, format="%Y%m%d").date())
        payload = _build_response(
            req, base_date, pred_mean, ci,
            trading_low, trading_high, direction_prob, last_close, y_timestamp,
        )
        _tasks[task_id] = {"status": "done", "result": payload, "error": None}
    except Exception as exc:
        # Any failure (bad ts_code, Tushare outage, model error) is recorded
        # on the task so the poller sees it instead of hanging on "pending".
        logger.exception("Task %s failed", task_id)
        _tasks[task_id] = {"status": "failed", "result": None, "error": str(exc)}
|
| 172 |
+
|
| 173 |
+
|
| 174 |
+
# ── Routes ────────────────────────────────────────────────────────────────────
@app.post(
    "/api/v1/predict",
    summary="提交预测任务",
    response_description="任务 ID,用于轮询结果",
)
async def submit_predict(req: PredictRequest):
    """
    提交一个蒙特卡洛预测任务,立即返回 `task_id`。
    通过 `GET /api/v1/predict/{task_id}` 轮询结果。
    """
    # Register the task as pending *before* scheduling it so a fast poller
    # never sees a 404 for a just-submitted id.
    new_id = str(uuid.uuid4())
    _tasks[new_id] = {"status": "pending", "result": None, "error": None}
    _executor.submit(_run_prediction, new_id, req)
    return {"task_id": new_id}
|
| 189 |
+
|
| 190 |
+
|
| 191 |
+
@app.get(
    "/api/v1/predict/{task_id}",
    summary="查询预测结果",
)
async def get_predict_result(task_id: str):
    """
    轮询预测任务状态。

    - `status: "pending"` — 正在计算
    - `status: "done"` — 完成,`result` 字段包含预测数据
    - `status: "failed"` — 失败,`error` 字段包含错误信息
    """
    # EAFP: a single dict lookup; unknown ids map to 404.
    try:
        return _tasks[task_id]
    except KeyError:
        raise HTTPException(status_code=404, detail=f"Task {task_id!r} not found")
|
| 207 |
+
|
| 208 |
+
|
| 209 |
+
@app.get("/health", summary="健康检查")
|
| 210 |
+
async def health():
|
| 211 |
+
return {"status": "ok"}
|
data_fetcher.py
ADDED
|
@@ -0,0 +1,77 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import logging
import os
from datetime import datetime, timedelta

import pandas as pd
import tushare as ts

logger = logging.getLogger(__name__)

# Fail soft when the token is absent: keep the original import-time behavior
# (configure the client unconditionally) but default to "" instead of None,
# so a missing env var surfaces as a clear Tushare auth failure at request
# time — with a warning in the logs — rather than an opaque error here.
TUSHARE_TOKEN = os.environ.get("TUSHARE_TOKEN", "")
if not TUSHARE_TOKEN:
    logger.warning("TUSHARE_TOKEN is not set; Tushare API calls will fail.")

ts.set_token(TUSHARE_TOKEN)
_pro = ts.pro_api()
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
def fetch_stock_data(
    ts_code: str, lookback: int
) -> tuple[pd.DataFrame, pd.Series, str]:
    """Fetch the most recent `lookback` daily qfq (front-adjusted) bars.

    Returns:
        x_df           : DataFrame with columns [open, high, low, close, volume, amount]
        x_timestamp    : pd.Series[datetime], aligned to x_df rows
        last_trade_date: str "YYYYMMDD", the most recent bar date
    """
    today = datetime.today()
    # 2× calendar buffer so weekends/holidays still leave `lookback` bars.
    window_start = today - timedelta(days=lookback * 2)

    raw = ts.pro_bar(
        ts_code=ts_code,
        adj="qfq",
        start_date=window_start.strftime("%Y%m%d"),
        end_date=today.strftime("%Y%m%d"),
        asset="E",
    )

    if raw is None or raw.empty:
        raise ValueError(f"No data returned for ts_code={ts_code!r}")

    # Chronological order, Kronos-friendly column name, newest `lookback` bars.
    bars = (
        raw.sort_values("trade_date")
        .rename(columns={"vol": "volume"})
        .tail(lookback)
        .reset_index(drop=True)
    )
    bars["timestamps"] = pd.to_datetime(bars["trade_date"], format="%Y%m%d")

    return (
        bars[["open", "high", "low", "close", "volume", "amount"]].copy(),
        bars["timestamps"].copy(),
        bars["trade_date"].iloc[-1],
    )
|
| 51 |
+
|
| 52 |
+
|
| 53 |
+
def get_future_trading_dates(last_trade_date: str, pred_len: int) -> pd.Series:
    """Return `pred_len` future SSE trading dates (datetime) after `last_trade_date`.

    Args:
        last_trade_date: base date in "YYYYMMDD" format.
        pred_len: number of future trading days required.

    Raises:
        ValueError: if the exchange calendar yields fewer than `pred_len`
            open days inside the queried window.
    """
    last_dt = datetime.strptime(last_trade_date, "%Y%m%d")
    # 3× buffer plus a flat 15-day margin: `pred_len * 3` alone is too tight
    # for small pred_len — e.g. pred_len=1 queries only 3 calendar days, which
    # a Spring Festival / Golden Week closure (7–8 days) would exceed.
    end_dt = last_dt + timedelta(days=pred_len * 3 + 15)

    cal = _pro.trade_cal(
        exchange="SSE",
        start_date=(last_dt + timedelta(days=1)).strftime("%Y%m%d"),
        end_date=end_dt.strftime("%Y%m%d"),
        is_open="1",  # open trading days only
    )
    cal = cal.sort_values("cal_date")
    dates = pd.to_datetime(cal["cal_date"].values[:pred_len], format="%Y%m%d")

    if len(dates) < pred_len:
        raise ValueError(
            f"Could only obtain {len(dates)} future trading dates; "
            f"increase buffer or check Tushare calendar coverage."
        )

    return pd.Series(dates)
|
predictor.py
ADDED
|
@@ -0,0 +1,121 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Kronos model singleton + Monte-Carlo prediction logic.
|
| 3 |
+
|
| 4 |
+
On import this module:
|
| 5 |
+
1. Clones shiyu-coder/Kronos from GitHub if not already present at KRONOS_DIR.
|
| 6 |
+
2. Adds KRONOS_DIR to sys.path so `from model import ...` works.
|
| 7 |
+
3. Does NOT load the model weights yet (lazy, first-request).
|
| 8 |
+
"""
|
| 9 |
+
|
| 10 |
+
import logging
|
| 11 |
+
import os
|
| 12 |
+
import subprocess
|
| 13 |
+
import sys
|
| 14 |
+
from typing import Tuple
|
| 15 |
+
|
| 16 |
+
import numpy as np
|
| 17 |
+
import pandas as pd
|
| 18 |
+
import torch
|
| 19 |
+
|
| 20 |
+
logger = logging.getLogger(__name__)
|
| 21 |
+
|
| 22 |
+
# ── Paths / IDs ─────────────────────────────────────────────────────────────
|
| 23 |
+
KRONOS_DIR = os.environ.get("KRONOS_DIR", "/app/Kronos")
|
| 24 |
+
MODEL_ID = "NeoQuasar/Kronos-base"
|
| 25 |
+
TOKENIZER_ID = "NeoQuasar/Kronos-Tokenizer-base"
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
# ── Bootstrap Kronos source ──────────────────────────────────────────────────
|
| 29 |
+
def _ensure_kronos_source() -> None:
|
| 30 |
+
if not os.path.isdir(KRONOS_DIR):
|
| 31 |
+
logger.info("Cloning Kronos source to %s …", KRONOS_DIR)
|
| 32 |
+
subprocess.run(
|
| 33 |
+
[
|
| 34 |
+
"git", "clone", "--depth", "1",
|
| 35 |
+
"https://github.com/shiyu-coder/Kronos",
|
| 36 |
+
KRONOS_DIR,
|
| 37 |
+
],
|
| 38 |
+
check=True,
|
| 39 |
+
)
|
| 40 |
+
if KRONOS_DIR not in sys.path:
|
| 41 |
+
sys.path.insert(0, KRONOS_DIR)
|
| 42 |
+
|
| 43 |
+
|
| 44 |
+
_ensure_kronos_source()
|
| 45 |
+
|
| 46 |
+
from model import Kronos, KronosPredictor, KronosTokenizer # noqa: E402 (after sys.path setup)
|
| 47 |
+
|
| 48 |
+
# ── Global singleton ─────────────────────────────────────────────────────────
|
| 49 |
+
_predictor: KronosPredictor | None = None
|
| 50 |
+
|
| 51 |
+
|
| 52 |
+
def get_predictor() -> KronosPredictor:
    """Return the process-wide KronosPredictor, loading it on first call."""
    global _predictor
    if _predictor is not None:
        return _predictor

    device = "cuda" if torch.cuda.is_available() else "cpu"
    logger.info("Loading Kronos model on %s …", device)
    tokenizer = KronosTokenizer.from_pretrained(TOKENIZER_ID)
    model = Kronos.from_pretrained(MODEL_ID)
    _predictor = KronosPredictor(model, tokenizer, device=device, max_context=512)
    logger.info("Kronos predictor ready.")
    return _predictor
|
| 62 |
+
|
| 63 |
+
|
| 64 |
+
# ── Monte-Carlo prediction ────────────────────────────────────────────────────
def run_mc_prediction(
    x_df: pd.DataFrame,
    x_timestamp: pd.Series,
    y_timestamp: pd.Series,
    pred_len: int,
    sample_count: int,
) -> Tuple[pd.DataFrame, dict, np.ndarray, np.ndarray, float, float]:
    """Run `sample_count` independent Kronos samples and build MC statistics.

    Each model call uses sample_count=1 so the draws are independent.

    Returns:
        pred_mean      : DataFrame (index aligned with y_timestamp, OHLCVA
                         columns) — per-step mean trajectory across samples
        ci             : dict[field]["low"/"high"] → ndarray(pred_len,), 95% CI
        trading_low    : ndarray(pred_len,), q2.5 of the predicted daily lows
        trading_high   : ndarray(pred_len,), q97.5 of the predicted daily highs
        direction_prob : float in [0, 1] — fraction of samples whose final
                         close exceeds the last historical close
        last_close     : float, closing price of the last historical bar
    """
    predictor = get_predictor()
    samples = []

    for _ in range(sample_count):
        s = predictor.predict(
            df=x_df,
            x_timestamp=x_timestamp,
            y_timestamp=y_timestamp,
            pred_len=pred_len,
            T=0.8,
            top_p=0.9,
            sample_count=1,
            verbose=False,
        )
        samples.append(s)

    # Mean trajectory: samples share the same index, so a level-0 groupby
    # averages each prediction step across all draws.
    pred_mean = pd.concat(samples).groupby(level=0).mean()

    alpha = 2.5  # → 95 % CI
    ci = {}
    for field in ["open", "high", "low", "close", "volume"]:
        # Stack each field once (shape: sample_count × pred_len) instead of
        # rebuilding the same array for the low and high percentiles.
        stacked = np.stack([s[field].values for s in samples])
        ci[field] = {
            "low": np.percentile(stacked, alpha, axis=0),
            "high": np.percentile(stacked, 100 - alpha, axis=0),
        }

    trading_low = ci["low"]["low"]      # q2.5 of the predicted daily low
    trading_high = ci["high"]["high"]   # q97.5 of the predicted daily high

    last_close = float(x_df["close"].iloc[-1])
    bull_count = sum(float(s["close"].iloc[-1]) > last_close for s in samples)
    direction_prob = bull_count / sample_count

    return pred_mean, ci, trading_low, trading_high, direction_prob, last_close
|
requirements.txt
ADDED
|
@@ -0,0 +1,12 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
fastapi>=0.111.0
|
| 2 |
+
uvicorn[standard]>=0.29.0
|
| 3 |
+
pydantic>=2.0.0
|
| 4 |
+
numpy
|
| 5 |
+
pandas==2.2.2
|
| 6 |
+
torch>=2.0.0
|
| 7 |
+
einops==0.8.1
|
| 8 |
+
huggingface_hub==0.33.1
|
| 9 |
+
matplotlib==3.9.3
|
| 10 |
+
tqdm==4.67.1
|
| 11 |
+
safetensors==0.6.2
|
| 12 |
+
tushare
|