Spaces:
Build error
Build error
| """ | |
| core/tft_model.py — Temporal Fusion Transformer 分位数预测模块 | |
| ================================================================ | |
| 基于 Google Research (2021) TFT 架构, 通过 Darts 框架实现. | |
| - 内置 Variable Selection Network: 自动学习因子重要性 | |
| - 内置 Temporal Attention: 识别关键时间步 | |
| - 原生多分位数输出: Q10/Q50/Q90 | |
| - 轻量化配置 (17K params): 适配小样本时序 (~276 月) | |
| """ | |
| import warnings | |
| import logging | |
| import numpy as np | |
| import pandas as pd | |
# Silence chatty third-party loggers and warnings *before* darts / lightning
# are imported below, so their import-time messages are suppressed too.
for _name in ('pytorch_lightning', 'lightning', 'pl', 'darts', 'lightning.pytorch'):
    logging.getLogger(_name).setLevel(logging.ERROR)
for _category in (FutureWarning, UserWarning):
    warnings.filterwarnings('ignore', category=_category)
| from darts import TimeSeries | |
| from darts.models import TFTModel | |
| from darts.utils.likelihood_models import QuantileRegression | |
| from darts.dataprocessing.transformers import Scaler | |
class TFTQuantilePredictor:
    """
    TFT quantile predictor used inside a walk-forward forecasting loop.

    Produces one-step-ahead Q10 / Q50 / Q90 forecasts of the target.

    Architecture (constructor defaults):
    - input_chunk_length=24 (two years of monthly look-back)
    - hidden_size=16 (small network, limits over-fitting)
    - lstm_layers=1, attention_heads=1
    - QuantileRegression likelihood over QUANTILES = (0.10, 0.50, 0.90)
    - a small, fixed number of epochs for fast training (no early-stopping
      callback is configured)

    NOTE(review): ``scaler_y`` / ``scaler_cov`` are created in ``__init__`` but
    never applied; target and covariates reach the model unscaled.
    """

    # Quantile levels of the likelihood and of the summary read from the samples.
    QUANTILES = (0.10, 0.50, 0.90)
    # Result keys, one per entry of QUANTILES and in the same order.
    OUTPUT_KEYS = ('tft_q10_1m', 'tft_q50_1m', 'tft_q90_1m')
    # Number of samples drawn from the fitted quantile distribution at predict time.
    NUM_SAMPLES = 200
    # Forecast quantiles are clipped to [-CLIP_BOUND, CLIP_BOUND].
    CLIP_BOUND = 0.50

    def __init__(self, input_chunk_length=24, hidden_size=16,
                 lstm_layers=1, n_heads=1, n_epochs=15,
                 batch_size=32, use_gpu=False):
        self.input_chunk_length = input_chunk_length
        self.hidden_size = hidden_size
        self.lstm_layers = lstm_layers
        self.n_heads = n_heads
        self.n_epochs = n_epochs
        self.batch_size = batch_size
        self.use_gpu = use_gpu
        self.model = None
        self.scaler_y = Scaler()
        self.scaler_cov = Scaler()
        # Set to True after the first successful fit_predict call.
        self._fitted = False

    def _build_model(self):
        """Instantiate a fresh darts TFTModel from this predictor's settings."""
        accelerator = 'gpu' if self.use_gpu else 'cpu'
        self.model = TFTModel(
            input_chunk_length=self.input_chunk_length,
            output_chunk_length=1,
            hidden_size=self.hidden_size,
            lstm_layers=self.lstm_layers,
            num_attention_heads=self.n_heads,
            dropout=0.1,
            likelihood=QuantileRegression(quantiles=list(self.QUANTILES)),
            n_epochs=self.n_epochs,
            batch_size=self.batch_size,
            add_relative_index=True,
            random_state=42,
            log_tensorboard=False,
            pl_trainer_kwargs={
                'enable_progress_bar': False,
                'accelerator': accelerator,
                'enable_model_summary': False,
            },
        )

    @staticmethod
    def _to_month_start(idx):
        """Map every timestamp to midnight on the first day of its month.

        Darts needs a regular 'MS' index; dropping the time-of-day as well
        keeps timestamps like '2020-01-31 15:30' on the grid.
        """
        return pd.DatetimeIndex(
            [ts.replace(day=1) for ts in pd.DatetimeIndex(idx)]
        ).normalize()

    def _prepare_training_frames(self, y_train, X_train):
        """Align target and features on a de-duplicated month-start index.

        Returns:
            (y_ms, X_ms) sharing the same month-start index (feature NaNs
            filled with 0), or None when too few usable rows remain.
        """
        valid_idx = y_train.dropna().index
        # Need a full look-back window plus some training samples.
        if len(valid_idx) < self.input_chunk_length + 10:
            return None
        common_idx = valid_idx.intersection(X_train.index)
        if len(common_idx) < self.input_chunk_length + 5:
            return None

        y_ms = y_train.loc[common_idx].sort_index()
        X_ms = X_train.loc[common_idx].sort_index().fillna(0)
        y_ms.index = self._to_month_start(y_ms.index)
        X_ms.index = self._to_month_start(X_ms.index)

        # Several dates in one month collapse to one label; keep the latest row.
        y_ms = y_ms[~y_ms.index.duplicated(keep='last')]
        X_ms = X_ms[~X_ms.index.duplicated(keep='last')]

        common = y_ms.index.intersection(X_ms.index)
        y_ms = y_ms.loc[common]
        X_ms = X_ms.loc[common]
        if len(y_ms) < self.input_chunk_length + 5:
            return None
        return y_ms, X_ms

    @classmethod
    def _quantiles_from_samples(cls, samples):
        """Summarise forecast samples into the result dict.

        Args:
            samples: array-like of forecast samples (e.g. darts' ``all_values()``
                array); it is flattened before use.

        Returns:
            dict mapping OUTPUT_KEYS to monotone quantiles clipped to
            [-CLIP_BOUND, CLIP_BOUND], or None if there are no samples or any
            quantile is NaN/inf.
        """
        flat = np.asarray(samples, dtype=float).ravel()
        if flat.size == 0:
            return None
        qs = np.quantile(flat, cls.QUANTILES)
        if not np.all(np.isfinite(qs)):
            return None
        # Sorting enforces Q10 <= Q50 <= Q90 before clipping.
        qs = np.clip(np.sort(qs), -cls.CLIP_BOUND, cls.CLIP_BOUND)
        return {key: float(value) for key, value in zip(cls.OUTPUT_KEYS, qs)}

    def fit_predict(self, y_train, X_train, X_test_row):
        """
        Train a fresh TFT and forecast Q10/Q50/Q90 for the next time step.

        Args:
            y_train: pd.Series — target (monthly returns), DatetimeIndex
            X_train: pd.DataFrame — feature matrix, DatetimeIndex
            X_test_row: pd.DataFrame — test features (one row expected), DatetimeIndex

        Returns:
            dict: {'tft_q10_1m': float, 'tft_q50_1m': float, 'tft_q90_1m': float},
            monotone and clipped to [-CLIP_BOUND, CLIP_BOUND]. None when there
            is too little data, when no test row lies in the month right after
            the last training month, when the forecast is not finite, or when
            training/prediction raises.
        """
        try:
            frames = self._prepare_training_frames(y_train, X_train)
            if frames is None:
                return None
            y_ms, X_ms = frames

            X_test_clean = X_test_row.fillna(0).copy()
            X_test_clean.index = self._to_month_start(X_test_clean.index)

            # predict(n=1) forecasts the month right after the last training
            # month. If the test row is not in that month, the forecast would
            # be attributed to the wrong period, so skip before training.
            forecast_month = y_ms.index.max() + pd.offsets.MonthBegin(1)
            if forecast_month not in X_test_clean.index:
                logging.getLogger(__name__).debug(
                    'TFT skipped: forecast month %s not in test index %s',
                    forecast_month, list(X_test_clean.index))
                return None

            # Target and covariates both use float32 so the series dtypes agree.
            y_ts = TimeSeries.from_times_and_values(
                y_ms.index, y_ms.values.astype(np.float32).reshape(-1, 1),
                freq='MS',
            )
            cov_ts = TimeSeries.from_times_and_values(
                X_ms.index, X_ms.values.astype(np.float32),
                columns=list(X_ms.columns), freq='MS',
            )

            # Covariates extended with the test row for the prediction call.
            # NOTE(review): with past_covariates and n=1 (== output_chunk_length),
            # darts presumably reads covariates only up to the last training
            # month, so the test row may not influence the forecast; if it
            # should, consider future_covariates — confirm against darts docs.
            cov_ext_df = pd.concat([X_ms, X_test_clean]).sort_index()
            cov_ext_df = cov_ext_df[~cov_ext_df.index.duplicated(keep='last')]
            cov_ext_ts = TimeSeries.from_times_and_values(
                cov_ext_df.index, cov_ext_df.values.astype(np.float32),
                columns=list(cov_ext_df.columns), freq='MS',
            )

            self._build_model()
            self.model.fit(y_ts, past_covariates=cov_ts, verbose=False)
            pred = self.model.predict(
                n=1,
                past_covariates=cov_ext_ts,
                num_samples=self.NUM_SAMPLES,
            )

            result = self._quantiles_from_samples(pred.all_values())
            if result is None:
                return None
            self._fitted = True
            return result
        except Exception:
            # Best-effort: a TFT failure must not crash the walk-forward
            # pipeline, but record why at DEBUG level instead of discarding it.
            logging.getLogger(__name__).debug(
                'TFT fit_predict failed; returning None', exc_info=True)
            return None
def tft_predict_step(train_df, test_df, sel_features, y_col='target_ret_1m'):
    """
    Run one walk-forward TFT step; convenience wrapper called from engine.py.

    Args:
        train_df: pd.DataFrame — training window (feature columns + target column)
        test_df: pd.DataFrame — test row to forecast (one row expected)
        sel_features: list — selected feature column names
        y_col: str — name of the target column

    Returns:
        dict or None: the TFT quantile forecast from
        TFTQuantilePredictor.fit_predict, or None when that returns None.
    """
    n_rows = len(train_df)
    # NOTE(review): fit_predict requires >= input_chunk_length + 10 non-NaN
    # targets. When n_rows < 29 the chunk length here is n_rows - 5, which can
    # never satisfy that, so forecasts only appear with the full 24-month chunk.
    chunk_length = min(24, n_rows - 5)
    batch = min(32, n_rows // 2)
    predictor = TFTQuantilePredictor(
        input_chunk_length=chunk_length,
        n_epochs=10,
        batch_size=batch,
        use_gpu=False,  # CPU is faster for such small batches
    )
    target = train_df[y_col]
    train_features = train_df[sel_features]
    test_features = test_df[sel_features]
    return predictor.fit_predict(target, train_features, test_features)