""" core/tft_model.py — Temporal Fusion Transformer 分位数预测模块 ================================================================ 基于 Google Research (2021) TFT 架构, 通过 Darts 框架实现. - 内置 Variable Selection Network: 自动学习因子重要性 - 内置 Temporal Attention: 识别关键时间步 - 原生多分位数输出: Q10/Q50/Q90 - 轻量化配置 (17K params): 适配小样本时序 (~276 月) """ import warnings import logging import numpy as np import pandas as pd # Suppress verbose logging for _name in ['pytorch_lightning', 'lightning', 'pl', 'darts', 'lightning.pytorch']: logging.getLogger(_name).setLevel(logging.ERROR) warnings.filterwarnings('ignore', category=FutureWarning) warnings.filterwarnings('ignore', category=UserWarning) from darts import TimeSeries from darts.models import TFTModel from darts.utils.likelihood_models import QuantileRegression from darts.dataprocessing.transformers import Scaler class TFTQuantilePredictor: """ TFT 分位数预测器 — 用于 walk-forward 预测循环. 输出: Q10, Q50, Q90 三个分位数预测值 Architecture: - input_chunk_length=24 (回看2年月度数据) - hidden_size=16 (轻量级, 防过拟合) - lstm_layers=1, attention_heads=1 - QuantileRegression likelihood [0.10, 0.50, 0.90] - 早停 + 低 epoch (快速训练) """ def __init__(self, input_chunk_length=24, hidden_size=16, lstm_layers=1, n_heads=1, n_epochs=15, batch_size=32, use_gpu=False): self.input_chunk_length = input_chunk_length self.hidden_size = hidden_size self.lstm_layers = lstm_layers self.n_heads = n_heads self.n_epochs = n_epochs self.batch_size = batch_size self.use_gpu = use_gpu self.model = None self.scaler_y = Scaler() self.scaler_cov = Scaler() self._fitted = False def _build_model(self): """构建 TFT 模型实例。""" accelerator = 'gpu' if self.use_gpu else 'cpu' self.model = TFTModel( input_chunk_length=self.input_chunk_length, output_chunk_length=1, hidden_size=self.hidden_size, lstm_layers=self.lstm_layers, num_attention_heads=self.n_heads, dropout=0.1, likelihood=QuantileRegression(quantiles=[0.10, 0.50, 0.90]), n_epochs=self.n_epochs, batch_size=self.batch_size, add_relative_index=True, random_state=42, log_tensorboard=False, pl_trainer_kwargs={ 'enable_progress_bar': False, 'accelerator': accelerator, 'enable_model_summary': False, }, ) def fit_predict(self, y_train, X_train, X_test_row): """ 训练 TFT 并预测下一个时间步的 Q10/Q50/Q90. Args: y_train: pd.Series — 目标变量 (月度收益率), DatetimeIndex X_train: pd.DataFrame — 特征矩阵, DatetimeIndex X_test_row: pd.DataFrame — 测试特征 (1行), DatetimeIndex Returns: dict: {'tft_q10_1m': float, 'tft_q50_1m': float, 'tft_q90_1m': float} or None if training fails """ try: # Minimum samples check n = len(y_train.dropna()) if n < self.input_chunk_length + 10: return None # Align data valid_idx = y_train.dropna().index common_idx = valid_idx.intersection(X_train.index) if len(common_idx) < self.input_chunk_length + 5: return None y_aligned = y_train.loc[common_idx].sort_index() X_aligned = X_train.loc[common_idx].sort_index().fillna(0) # Normalize dates to month-start (Darts requires regular freq) def to_month_start(idx): return pd.DatetimeIndex([d.replace(day=1) for d in idx]) y_ms = y_aligned.copy() y_ms.index = to_month_start(y_ms.index) X_ms = X_aligned.copy() X_ms.index = to_month_start(X_ms.index) # Drop duplicate months (if any) y_ms = y_ms[~y_ms.index.duplicated(keep='last')] X_ms = X_ms[~X_ms.index.duplicated(keep='last')] # Ensure aligned common = y_ms.index.intersection(X_ms.index) y_ms = y_ms.loc[common] X_ms = X_ms.loc[common] if len(y_ms) < self.input_chunk_length + 5: return None # Convert to Darts TimeSeries y_ts = TimeSeries.from_times_and_values( y_ms.index, y_ms.values.reshape(-1, 1), freq='MS' ) cov_values = X_ms.values.astype(np.float32) cov_ts = TimeSeries.from_times_and_values( X_ms.index, cov_values, columns=list(X_ms.columns), freq='MS' ) # Extend covariates with test point X_test_clean = X_test_row.fillna(0).copy() X_test_clean.index = to_month_start(X_test_clean.index) cov_ext_df = pd.concat([X_ms, X_test_clean]).sort_index() cov_ext_df = cov_ext_df[~cov_ext_df.index.duplicated(keep='last')] cov_ext_ts = TimeSeries.from_times_and_values( cov_ext_df.index, cov_ext_df.values.astype(np.float32), columns=list(cov_ext_df.columns), freq='MS' ) # Build & train self._build_model() self.model.fit(y_ts, past_covariates=cov_ts, verbose=False) # Predict pred = self.model.predict( n=1, past_covariates=cov_ext_ts, num_samples=200, ) # Extract quantiles from probabilistic prediction vals = pred.all_values() # shape: (1, n_components, n_samples) samples = vals.flatten() # all 200 samples q10 = float(np.percentile(samples, 10)) q50 = float(np.percentile(samples, 50)) q90 = float(np.percentile(samples, 90)) # Sanity check if any(np.isnan(v) or np.isinf(v) for v in [q10, q50, q90]): return None # Enforce monotonicity vals = sorted([q10, q50, q90]) result = { 'tft_q10_1m': float(np.clip(vals[0], -0.50, 0.50)), 'tft_q50_1m': float(np.clip(vals[1], -0.50, 0.50)), 'tft_q90_1m': float(np.clip(vals[2], -0.50, 0.50)), } self._fitted = True return result except Exception as e: # Silent fallback — TFT failure should not crash pipeline return None def tft_predict_step(train_df, test_df, sel_features, y_col='target_ret_1m'): """ Walk-forward 单步 TFT 预测 — 供 engine.py 调用的便捷函数. Args: train_df: pd.DataFrame — 训练窗口数据 (含特征+目标) test_df: pd.DataFrame — 测试行(1行) sel_features: list — 选定特征名 y_col: str — 目标列名 Returns: dict or None: TFT quantile predictions """ predictor = TFTQuantilePredictor( input_chunk_length=min(24, len(train_df) - 5), n_epochs=10, batch_size=min(32, len(train_df) // 2), use_gpu=False, # CPU faster for small batches ) y_train = train_df[y_col] X_train = train_df[sel_features] X_test = test_df[sel_features] return predictor.fit_predict(y_train, X_train, X_test)