# oilverse-api / core / tft_model.py
# Author: 孙家明
# Commit: fab9847 — deploy: OilVerse for HuggingFace (Node.js 18 fix)
"""
core/tft_model.py — Temporal Fusion Transformer 分位数预测模块
================================================================
基于 Google Research (2021) TFT 架构, 通过 Darts 框架实现.
- 内置 Variable Selection Network: 自动学习因子重要性
- 内置 Temporal Attention: 识别关键时间步
- 原生多分位数输出: Q10/Q50/Q90
- 轻量化配置 (17K params): 适配小样本时序 (~276 月)
"""
import warnings
import logging
import numpy as np
import pandas as pd
# Silence the chatty PyTorch Lightning / Darts loggers; during the
# walk-forward loop only ERROR-level messages should get through.
for _name in ['pytorch_lightning', 'lightning', 'pl', 'darts', 'lightning.pytorch']:
    logging.getLogger(_name).setLevel(logging.ERROR)
# Lightning/Darts also emit Future/UserWarnings on every fit; hide those too.
warnings.filterwarnings('ignore', category=FutureWarning)
warnings.filterwarnings('ignore', category=UserWarning)
# NOTE(review): darts is imported after the logging setup above — presumably
# so any import-time log output is already muted; confirm before reordering.
from darts import TimeSeries
from darts.models import TFTModel
from darts.utils.likelihood_models import QuantileRegression
from darts.dataprocessing.transformers import Scaler
class TFTQuantilePredictor:
    """
    TFT quantile predictor for the walk-forward forecasting loop.

    Wraps Darts' TFTModel to forecast the next time step of a monthly
    target series as three quantiles: Q10, Q50, Q90.

    Architecture:
        - input_chunk_length=24 (2-year look-back on monthly data)
        - hidden_size=16 (lightweight, guards against overfitting)
        - lstm_layers=1, attention_heads=1
        - QuantileRegression likelihood over [0.10, 0.50, 0.90]
        - low epoch count for fast per-window training
    """

    def __init__(self, input_chunk_length=24, hidden_size=16,
                 lstm_layers=1, n_heads=1, n_epochs=15,
                 batch_size=32, use_gpu=False):
        self.input_chunk_length = input_chunk_length
        self.hidden_size = hidden_size
        self.lstm_layers = lstm_layers
        self.n_heads = n_heads
        self.n_epochs = n_epochs
        self.batch_size = batch_size
        self.use_gpu = use_gpu
        self.model = None
        # NOTE(review): these scalers are instantiated but never used inside
        # fit_predict — kept for interface compatibility; verify whether
        # external callers rely on them before removing.
        self.scaler_y = Scaler()
        self.scaler_cov = Scaler()
        self._fitted = False

    def _build_model(self):
        """Construct a fresh TFTModel instance configured for quiet CPU/GPU training."""
        accelerator = 'gpu' if self.use_gpu else 'cpu'
        self.model = TFTModel(
            input_chunk_length=self.input_chunk_length,
            output_chunk_length=1,
            hidden_size=self.hidden_size,
            lstm_layers=self.lstm_layers,
            num_attention_heads=self.n_heads,
            dropout=0.1,
            likelihood=QuantileRegression(quantiles=[0.10, 0.50, 0.90]),
            n_epochs=self.n_epochs,
            batch_size=self.batch_size,
            # Relative index stands in for future covariates, which TFT
            # otherwise requires.
            add_relative_index=True,
            random_state=42,
            log_tensorboard=False,
            pl_trainer_kwargs={
                'enable_progress_bar': False,
                'accelerator': accelerator,
                'enable_model_summary': False,
            },
        )

    def fit_predict(self, y_train, X_train, X_test_row):
        """
        Train TFT on the window and predict Q10/Q50/Q90 for the next step.

        Args:
            y_train: pd.Series — target variable (monthly returns), DatetimeIndex
            X_train: pd.DataFrame — feature matrix, DatetimeIndex
            X_test_row: pd.DataFrame — test features (1 row), DatetimeIndex

        Returns:
            dict: {'tft_q10_1m': float, 'tft_q50_1m': float, 'tft_q90_1m': float}
            or None when there is too little data or training/prediction fails.
        """
        try:
            # Minimum-samples check; dropna is computed once and reused below
            # (the original recomputed it).
            y_valid = y_train.dropna()
            if len(y_valid) < self.input_chunk_length + 10:
                return None
            # Align target and features on their common dates.
            common_idx = y_valid.index.intersection(X_train.index)
            if len(common_idx) < self.input_chunk_length + 5:
                return None
            y_aligned = y_train.loc[common_idx].sort_index()
            X_aligned = X_train.loc[common_idx].sort_index().fillna(0)

            # Normalize dates to month-start (Darts requires a regular freq).
            def to_month_start(idx):
                return pd.DatetimeIndex([d.replace(day=1) for d in idx])

            y_ms = y_aligned.copy()
            y_ms.index = to_month_start(y_ms.index)
            X_ms = X_aligned.copy()
            X_ms.index = to_month_start(X_ms.index)
            # Drop duplicate months (if any), keeping the latest observation.
            y_ms = y_ms[~y_ms.index.duplicated(keep='last')]
            X_ms = X_ms[~X_ms.index.duplicated(keep='last')]
            # Re-align after dedup, then re-check the minimum window.
            common = y_ms.index.intersection(X_ms.index)
            y_ms = y_ms.loc[common]
            X_ms = X_ms.loc[common]
            if len(y_ms) < self.input_chunk_length + 5:
                return None

            # Convert to Darts TimeSeries (monthly-start frequency).
            y_ts = TimeSeries.from_times_and_values(
                y_ms.index, y_ms.values.reshape(-1, 1),
                freq='MS'
            )
            cov_ts = TimeSeries.from_times_and_values(
                X_ms.index, X_ms.values.astype(np.float32),
                columns=list(X_ms.columns), freq='MS'
            )
            # Extend covariates with the test point so predict() can see the
            # covariate row for the forecasted month.
            X_test_clean = X_test_row.fillna(0)
            X_test_clean.index = to_month_start(X_test_clean.index)
            cov_ext_df = pd.concat([X_ms, X_test_clean]).sort_index()
            cov_ext_df = cov_ext_df[~cov_ext_df.index.duplicated(keep='last')]
            cov_ext_ts = TimeSeries.from_times_and_values(
                cov_ext_df.index, cov_ext_df.values.astype(np.float32),
                columns=list(cov_ext_df.columns), freq='MS'
            )

            # Build & train a fresh model for this window.
            self._build_model()
            self.model.fit(y_ts, past_covariates=cov_ts, verbose=False)
            # Probabilistic one-step-ahead forecast (200 sample paths).
            pred = self.model.predict(
                n=1,
                past_covariates=cov_ext_ts,
                num_samples=200,
            )
            # Extract empirical quantiles from the sampled prediction.
            samples = pred.all_values().flatten()  # (1, n_components, n_samples) -> flat
            q10 = float(np.percentile(samples, 10))
            q50 = float(np.percentile(samples, 50))
            q90 = float(np.percentile(samples, 90))
            # Sanity check against degenerate training runs.
            if any(np.isnan(v) or np.isinf(v) for v in (q10, q50, q90)):
                return None
            # Enforce quantile monotonicity (sample quantiles can cross),
            # then clip to a +/-50% monthly-return band. (Renamed from `vals`,
            # which shadowed the prediction array above.)
            q_lo, q_mid, q_hi = sorted([q10, q50, q90])
            result = {
                'tft_q10_1m': float(np.clip(q_lo, -0.50, 0.50)),
                'tft_q50_1m': float(np.clip(q_mid, -0.50, 0.50)),
                'tft_q90_1m': float(np.clip(q_hi, -0.50, 0.50)),
            }
            self._fitted = True
            return result
        except Exception:
            # Silent fallback — a TFT failure must not crash the pipeline.
            return None
def tft_predict_step(train_df, test_df, sel_features, y_col='target_ret_1m'):
    """
    Single walk-forward TFT prediction step — convenience wrapper for engine.py.

    Args:
        train_df: pd.DataFrame — training-window data (features + target)
        test_df: pd.DataFrame — test row (1 row)
        sel_features: list — selected feature names
        y_col: str — target column name

    Returns:
        dict or None: TFT quantile predictions, or None when the window is
        too small or the underlying fit fails.
    """
    # Guard tiny windows explicitly: the original allowed a non-positive
    # input_chunk_length / zero batch_size and relied on the exception being
    # swallowed inside fit_predict; such windows always yield None anyway.
    if len(train_df) < 16:
        return None
    predictor = TFTQuantilePredictor(
        # Clamp to at least 1 so the model config is always valid.
        input_chunk_length=max(1, min(24, len(train_df) - 5)),
        n_epochs=10,
        batch_size=max(1, min(32, len(train_df) // 2)),
        use_gpu=False,  # CPU is faster than GPU for these tiny batches
    )
    y_train = train_df[y_col]
    X_train = train_df[sel_features]
    X_test = test_df[sel_features]
    return predictor.fit_predict(y_train, X_train, X_test)