#!/usr/bin/env python3
"""
run.py — Main entry point: data update → engine → analysis → hedging → causal → dashboard
=========================================================================================
Usage:
    python run.py                # full pipeline (incl. data update + news)
    python run.py --skip-update  # skip the data update, use existing data
"""
import argparse
import json
import os
import webbrowser  # NOTE(review): not used in this file — confirm before removing

import pandas as pd

from config import BASE_DIR, OUTPUT_DIR, OUTPUT_FILES, INDUSTRIES, PRICE_COLS
from core.engine import load_panel, run_walk_forward
from core.analysis import apply_industry_rules, generate_all_reports, evaluate_results, run_ablation
from core.hedging import compute_all_industry_hedges, backtest_hedging
from core.feature_selection import run_feature_funnel

# Run relative to the project root so the relative 'output/...' paths resolve.
os.chdir(BASE_DIR)
os.makedirs(OUTPUT_DIR, exist_ok=True)
def run_benchmark(panel_path, benchmark, price_col, features_override=None):
    """Run the full walk-forward pipeline for a single oil-price benchmark.

    Parameters
    ----------
    panel_path : str
        CSV path of the monthly feature panel.
    benchmark : str
        Benchmark label (e.g. 'WTI', 'Brent'); written into the results'
        ``benchmark`` column.
    price_col : str
        Column of the panel holding this benchmark's price series.
    features_override : list | None
        Optional explicit feature list. When None (default), the feature
        list returned by ``load_panel`` is used.
        (Previously this parameter was accepted but ignored; it is now
        honored, with unchanged default behavior.)

    Returns
    -------
    tuple
        ``(results, shap_data, reports, panel, features)``.
    """
    print(f"\n{'─'*65}")
    print(f" Benchmark: {benchmark} ({price_col})")
    print(f"{'─'*65}")
    # Load panel
    panel, features = load_panel(panel_path, price_col=price_col)
    if features_override is not None:
        features = features_override
    # Walk-forward predictions
    results, shap_data = run_walk_forward(panel, features)
    print(f" 测试月数: {len(results)}")
    # Apply per-industry rules row by row, writing each rule value back
    # into the results frame.
    for i, row in results.iterrows():
        rules = apply_industry_rules(row)
        for k, v in rules.items():
            results.at[i, k] = v
    # NLG reports
    reports = generate_all_reports(results)
    # Tag benchmark
    results['benchmark'] = benchmark
    return results, shap_data, reports, panel, features
def main(skip_update=False):
    """End-to-end pipeline driver.

    Step 0 optionally refreshes the data panel, Step 1 runs the feature
    selection funnel, Step 2 runs walk-forward prediction for each
    benchmark, Steps 3–7b compute hedging, reports, evaluation, saved
    artifacts, the ablation study, and causal analysis.

    Parameters
    ----------
    skip_update : bool
        When True, skip the FRED/EIA data update and use existing panels.
    """
    # ════ Step 0: Data Update ════
    if not skip_update:
        print("═" * 65)
        print("Step 0: 全特征 API 数据更新")
        print("═" * 65)
        try:
            from pipeline.live_data import main as live_main
            live_main()
            panel_path = 'output/panel_monthly_live.csv'
            if os.path.exists(panel_path):
                print(f"✓ 使用更新后的面板: {panel_path}")
            else:
                panel_path = 'output/panel_monthly.csv'
        except Exception as e:
            # Best-effort update: any failure falls back to the saved panel.
            print(f"⚠ 数据更新跳过: {e}")
            panel_path = 'output/panel_monthly.csv'
    else:
        print("跳过数据更新")
        if os.path.exists('output/panel_monthly_live.csv'):
            panel_path = 'output/panel_monthly_live.csv'
        else:
            panel_path = 'output/panel_monthly.csv'

    # ════ Step 1: Feature Selection Funnel ════
    print("\n" + "═" * 65)
    print("Step 1: 特征筛选漏斗 (329→17)")
    print("═" * 65)
    # NOTE(review): the funnel always runs on the base panel rather than
    # `panel_path` — confirm this is intentional.
    funnel = run_feature_funnel('output/panel_monthly.csv')
    with open(OUTPUT_FILES['feat_sel'], 'w', encoding='utf-8') as f:
        json.dump(funnel, f, ensure_ascii=False, indent=2)
    print(f"✓ 特征筛选: {OUTPUT_FILES['feat_sel']}")

    # ════ Step 2: Walk-Forward for EACH benchmark ════
    print("\n" + "═" * 65)
    print("Step 2: Walk-Forward 预测 (WTI + Brent)")
    print("═" * 65)
    all_results = {}
    all_shap = {}
    all_reports = {}
    all_panels = {}
    all_features = {}
    for benchmark, price_col in PRICE_COLS.items():
        results, shap_data, reports, panel, features = run_benchmark(
            panel_path, benchmark, price_col)
        all_results[benchmark] = results
        all_shap[benchmark] = shap_data
        all_reports[benchmark] = reports
        all_panels[benchmark] = panel
        all_features[benchmark] = features

    # Use WTI as primary for hedging/evaluation (backward compat).
    primary = 'WTI'
    results = all_results[primary]

    # ════ Step 3: Hedging (based on WTI) ════
    print("\n" + "═" * 65)
    print("Step 3: 对冲决策计算")
    print("═" * 65)
    latest = results.iloc[-1]
    hedging_data = compute_all_industry_hedges(latest)
    # Project each industry's hedge decision into a JSON-safe dict.
    hedging_json = {}
    for ind, hd in hedging_data.items():
        hedging_json[ind] = {
            'industry_zh': hd['industry_zh'],
            'exposure': hd['exposure'],
            'elasticity': hd['elasticity'],
            'recommended_ratio': hd['recommended_ratio'],
            'recommended_ratio_pct': hd['recommended_ratio_pct'],
            'recommended_tool': hd['recommended_tool'],
            'rationale': hd['rationale'],
            'matrix': hd['matrix'],
            'tool_comparison': hd['tool_comparison'],
        }
        print(f" {hd['industry_zh']}: 推荐对冲 {hd['recommended_ratio_pct']}, "
              f"工具={hd['recommended_tool']}")
    with open(OUTPUT_FILES['hedging'], 'w', encoding='utf-8') as f:
        json.dump(hedging_json, f, ensure_ascii=False, indent=2)
    print(f"✓ Hedging: {OUTPUT_FILES['hedging']}")

    # Hedge backtest
    print(" [回测对冲策略 — 过去60月]")
    backtest_data = backtest_hedging(results)
    backtest_json = {}
    for ind, bt in backtest_data.items():
        backtest_json[ind] = bt
        print(f" {bt['industry_zh']}: 累计节省${bt['total_saving']:.1f}M, "
              f"波动率降低{bt['vol_reduction']:.0f}%, "
              f"最大回撤改善${bt['dd_improvement']:.1f}M")
    with open(OUTPUT_FILES['backtest'], 'w', encoding='utf-8') as f:
        json.dump(backtest_json, f, ensure_ascii=False, indent=2)
    print(f"✓ Backtest: {OUTPUT_FILES['backtest']}")

    # ════ Step 4: NLG Reports ════
    print("\n" + "═" * 65)
    print("Step 4: NLG 报告生成")
    print("═" * 65)
    for bm, reports in all_reports.items():
        print(f" {bm}: {len(reports)} 份报告")

    # ════ Step 5: Evaluation ════
    for bm, res in all_results.items():
        print(f"\n--- Evaluation: {bm} ---")
        evaluate_results(res)

    # ════ Step 6: Save ════
    print("\n" + "═" * 65)
    print("Step 6: 保存结果")
    print("═" * 65)
    # Per-benchmark results
    for bm, res in all_results.items():
        out_path = os.path.join(OUTPUT_DIR, f'v2_results_{bm}.csv')
        res.to_csv(out_path, index=False)
        print(f"✓ 结果 [{bm}]: {out_path}")
    # Primary benchmark also saved under the legacy path (backward compat).
    results.to_csv(OUTPUT_FILES['results'], index=False)
    print(f"✓ 结果 [primary]: {OUTPUT_FILES['results']}")
    # SHAP (primary, last 12 entries only)
    with open(OUTPUT_FILES['shap'], 'w', encoding='utf-8') as f:
        json.dump(all_shap[primary][-12:], f, ensure_ascii=False, indent=2)
    print(f"✓ SHAP: {OUTPUT_FILES['shap']}")
    # NLG: merge all benchmarks; primary keys stay bare, others get a suffix.
    merged_reports = {}
    for bm, reps in all_reports.items():
        for dt, report in reps.items():
            key = f"{dt}_{bm}" if bm != primary else dt
            merged_reports[key] = report
        # Also save per-benchmark
        with open(os.path.join(OUTPUT_DIR, f'v2_nlg_{bm}.json'), 'w', encoding='utf-8') as f:
            json.dump(reps, f, ensure_ascii=False, indent=2)
    with open(OUTPUT_FILES['nlg'], 'w', encoding='utf-8') as f:
        json.dump(merged_reports, f, ensure_ascii=False, indent=2)
    print(f"✓ NLG: {OUTPUT_FILES['nlg']}")
    # Scenarios (primary, last 12 months)
    scenario_data = {}
    for _, row in results.tail(12).iterrows():
        dt = pd.Timestamp(row['test_date']).strftime('%Y-%m')
        scenario_data[dt] = {
            'base': round(row['scenario_base'] * 100, 2),
            'vix_shock': round(row['scenario_vix_shock'] * 100, 2),
            'supply_cut': round(row['scenario_supply_cut'] * 100, 2),
            'demand_crash': round(row['scenario_demand_crash'] * 100, 2),
        }
    with open(OUTPUT_FILES['scenarios'], 'w', encoding='utf-8') as f:
        json.dump(scenario_data, f, indent=2)
    print(f"✓ Scenarios: {OUTPUT_FILES['scenarios']}")
    # Regime (primary)
    regime_data = {}
    for _, row in results.iterrows():
        dt = pd.Timestamp(row['test_date']).strftime('%Y-%m')
        regime_data[dt] = {
            'match': row.get('regime_match', 'Unknown'),
            'similarity': row.get('regime_similarity', 0),
            'type': row.get('regime_type', 'normal'),
        }
    with open(OUTPUT_FILES['regime'], 'w', encoding='utf-8') as f:
        json.dump(regime_data, f, ensure_ascii=False, indent=2)
    print(f"✓ Regime: {OUTPUT_FILES['regime']}")

    # ════ Step 7: Ablation (primary only) ════
    print("\n" + "═" * 65)
    print("Step 7: 消融实验")
    print("═" * 65)
    ablation_results = run_ablation(all_panels[primary], all_features[primary])
    with open(OUTPUT_FILES['ablation'], 'w') as f:
        json.dump(ablation_results, f, indent=2)
    print(f"✓ Ablation: {OUTPUT_FILES['ablation']}")

    # ════ Step 7b: Causal Analysis ════
    print("\n" + "═" * 65)
    print("Step 7b: 因果因子网络分析")
    print("═" * 65)
    try:
        from pipeline.causal_analysis import run_full_causal_analysis
        run_full_causal_analysis(panel_path)
        print(f"✓ 因果分析: {OUTPUT_FILES.get('causal', 'output/causal_analysis.json')}")
    except Exception as e:
        # Causal analysis is optional; never abort the pipeline for it.
        print(f"⚠ 因果分析跳过: {e}")

    # ════ Step 8: Done ════
    print("\n" + "═" * 65)
    print("✅ 全部完成!")
    print("═" * 65)
    print(" 启动前端: cd frontend && npm run dev")
    print(" 启动API: python api_server.py")
if __name__ == '__main__':
    # CLI entry point: only flag is --skip-update (reuse existing panels).
    parser = argparse.ArgumentParser(description='油价风险分析平台 — 一键启动')
    parser.add_argument('--skip-update', action='store_true',
                        help='跳过 FRED/EIA 数据更新,直接使用现有数据')
    args = parser.parse_args()
    main(skip_update=args.skip_update)