Spaces:
Sleeping
Sleeping
| import gradio as gr | |
| import plotly.graph_objects as go | |
| from plotly.subplots import make_subplots | |
| import pandas as pd | |
| import numpy as np | |
| import yfinance as yf | |
| from datetime import datetime, timedelta | |
| from src.inference import predict_horizons | |
# --- Helper functions: technical-indicator calculations for the UI ---
def calculate_ui_technical_indicators(df):
    """Attach the indicator columns used by the chart and the market-context panel.

    The frame is modified in place *and* returned. It must provide ``Close``,
    ``High`` and ``Low`` columns. Columns added, in this order:

    * ``SMA_20`` / ``SMA_50``: simple rolling means of ``Close``.
    * ``RSI_14``: RSI built from plain 14-period rolling means of gains and
      losses (a simple mean, not Wilder's exponential smoothing).
    * ``MACD`` / ``MACD_Signal`` / ``MACD_Hist``: the 12/26-span EMA
      difference, its 9-span EMA, and the gap between the two.
    * ``ATR_14``: 14-period rolling mean of the true range.
    """
    close = df['Close']

    for window in (20, 50):
        df[f'SMA_{window}'] = close.rolling(window=window).mean()

    # Split daily moves into up/down legs. where() also turns the leading NaN
    # from diff() into 0, so it counts toward the first 14-period window.
    move = close.diff()
    avg_up = move.where(move > 0, 0).rolling(window=14).mean()
    avg_down = -move.where(move < 0, 0).rolling(window=14).mean()
    df['RSI_14'] = 100 - (100 / (1 + avg_up / avg_down))

    fast_ema, slow_ema = (close.ewm(span=span, adjust=False).mean() for span in (12, 26))
    df['MACD'] = fast_ema - slow_ema
    df['MACD_Signal'] = df['MACD'].ewm(span=9, adjust=False).mean()
    df['MACD_Hist'] = df['MACD'] - df['MACD_Signal']

    # True range = largest of: today's range, and today's high/low vs. yesterday's close.
    prev_close = close.shift()
    true_range = pd.concat(
        [
            df['High'] - df['Low'],
            np.abs(df['High'] - prev_close),
            np.abs(df['Low'] - prev_close),
        ],
        axis=1,
    ).max(axis=1)
    df['ATR_14'] = true_range.rolling(window=14).mean()
    return df
# --- UI handling & forecasting logic ---
_HORIZONS = (1, 7, 21)  # forecast horizons; `preds` is keyed by these ints, and the chart offsets them in business days
_BOTH_MODELS = "Cả Hai"  # dropdown value meaning "average the Linear Regression and SVR forecasts"


def _render_error_html(exc):
    """Return the red error banner shown in the consensus panel when the dashboard fails."""
    import html

    # Escape the message: exception text can echo input or remote data and is injected into HTML.
    return f"""<div style='background-color:#3a1010; padding:15px; border-left: 4px solid #ff4d4d; color: #ff8080;'>
    <strong>[SYSTEM ERROR]</strong> Xảy ra lỗi trong quá trình inference: {html.escape(str(exc))}</div>"""


def _load_price_history(ticker, last_date):
    """Download ~90 calendar days of OHLCV ending at `last_date` and add the UI indicators.

    Raises:
        ValueError: if the download returns no rows.
    """
    # +1 day because yfinance treats `end` as exclusive; this keeps `last_date` itself in the window.
    end_dt = datetime.strptime(last_date, '%Y-%m-%d') + timedelta(days=1)
    start_dt = end_dt - timedelta(days=90)
    df = yf.download(ticker, start=start_dt.strftime('%Y-%m-%d'), end=end_dt.strftime('%Y-%m-%d'), progress=False)
    if df.empty:
        raise ValueError(f"No OHLCV data returned for {ticker} between {start_dt:%Y-%m-%d} and {end_dt:%Y-%m-%d}")
    # Presumably (field, ticker) columns from newer yfinance; keep only the field level.
    if isinstance(df.columns, pd.MultiIndex):
        df.columns = df.columns.droplevel(1)
    df = df.reset_index()  # moves the date index into a 'Date' column used for the x-axis
    return calculate_ui_technical_indicators(df)


def _model_value(preds, model_name, horizon, field):
    """Return `field` at `horizon` for the selected model, or the LR/SVR mean when both are selected."""
    by_model = preds[horizon]
    if model_name == _BOTH_MODELS:
        return (by_model["Linear Regression"][field] + by_model["SVR"][field]) / 2
    return by_model[model_name][field]


def _classify_signal(ret_7d, ret_21d):
    """Map 7- and 21-day expected returns to (label, text colour, badge background colour)."""
    if ret_21d > 0.02 and ret_7d > 0:
        return "STRONG BUY", "#00ff00", "rgba(0, 255, 0, 0.1)"
    if ret_21d > 0:
        return "ACCUMULATE", "#00cc66", "rgba(0, 204, 102, 0.1)"
    if ret_21d < -0.02 and ret_7d < 0:
        return "STRONG SELL", "#ff0000", "rgba(255, 0, 0, 0.1)"
    return "REDUCE / SELL", "#ff6666", "rgba(255, 102, 102, 0.1)"


def _fmt_ret(val):
    """Render a fractional return as a bold, signed percentage <span> (green if > 0, else red)."""
    # Slightly darker green so it stays readable on a white background.
    color = "#00cc00" if val > 0 else "#ff4444"
    sign = "+" if val > 0 else ""
    return f"<span style='color: {color}; font-weight: bold;'>{sign}{val*100:.2f}%</span>"


def _build_consensus_html(target, ret):
    """Return the AI-consensus panel: the overall signal badge plus per-horizon targets and returns."""
    signal_text, signal_color, bg_color = _classify_signal(ret[7], ret[21])
    return f"""
    <div style="background: var(--background-fill-secondary); border: 1px solid var(--border-color-primary); padding: 15px; border-radius: 5px; font-family: monospace; margin-bottom: 15px;">
    <div style="display: flex; justify-content: space-between; align-items: center; border-bottom: 1px solid var(--border-color-primary); padding-bottom: 10px; margin-bottom: 15px;">
    <h3 style="color: var(--body-text-color-subdued); margin:0; font-size:14px;">AI CONSENSUS SIGNAL</h3>
    <span style="background:{bg_color}; color:{signal_color}; padding: 4px 10px; border-radius: 3px; border: 1px solid {signal_color}; font-weight: bold; font-size: 14px;">
    {signal_text}
    </span>
    </div>
    <table style="width: 100%; color: var(--body-text-color); font-size: 13px; text-align: left; border-collapse: collapse;">
    <tr style="color: var(--body-text-color-subdued);">
    <th style="padding-bottom: 8px;">Horizon</th>
    <th style="text-align: right; padding-bottom: 8px;">Target Price</th>
    <th style="text-align: right; padding-bottom: 8px;">Exp. Return</th>
    </tr>
    <tr style="border-bottom: 1px dashed var(--border-color-primary);">
    <td style="padding: 8px 0;">T+1 (Day)</td>
    <td style="text-align: right; color:#d4af37; font-weight: bold;">${target[1]:.2f}</td>
    <td style="text-align: right;">{_fmt_ret(ret[1])}</td>
    </tr>
    <tr style="border-bottom: 1px dashed var(--border-color-primary);">
    <td style="padding: 8px 0;">T+7 (Week)</td>
    <td style="text-align: right; color:#d48b00; font-weight: bold;">${target[7]:.2f}</td>
    <td style="text-align: right;">{_fmt_ret(ret[7])}</td>
    </tr>
    <tr>
    <td style="padding: 8px 0;">T+21 (Month)</td>
    <td style="text-align: right; color:#cc4400; font-weight: bold;">${target[21]:.2f}</td>
    <td style="text-align: right;">{_fmt_ret(ret[21])}</td>
    </tr>
    </table>
    </div>
    """


def _build_context_html(last_close, last_date, target):
    """Return the market-context panel: last close price/date and the per-horizon price targets."""
    return f"""
    <div style="background: var(--background-fill-secondary); border: 1px solid var(--border-color-primary); padding: 15px; border-radius: 5px; font-family: monospace;">
    <p style="color: var(--body-text-color-subdued); margin:0 0 10px 0; font-size:12px;">LAST CLOSE: {last_date}</p>
    <h2 style="color: var(--body-text-color); margin:0 0 15px 0; border-bottom: 1px solid var(--border-color-primary); padding-bottom: 10px;">${last_close:.2f}</h2>
    <table style="width: 100%; color: var(--body-text-color-subdued); font-size: 13px;">
    <tr><td style="padding: 4px 0;">Target T+1 (Day)</td><td style="text-align: right; font-weight: bold; color: #d4af37;">${target[1]:.2f}</td></tr>
    <tr><td style="padding: 4px 0;">Target T+7 (Week)</td><td style="text-align: right; font-weight: bold; color: #d48b00;">${target[7]:.2f}</td></tr>
    <tr><td style="padding: 4px 0;">Target T+21 (Month)</td><td style="text-align: right; font-weight: bold; color: #cc4400;">${target[21]:.2f}</td></tr>
    </table>
    </div>
    """


def _build_chart(ticker, model_name, preds, df_ui, last_close, base_date, target, atr_val):
    """Return the candlestick + volume figure overlaid with model trajectories and the ATR risk cone."""
    fig = make_subplots(rows=2, cols=1, shared_xaxes=True,
                        vertical_spacing=0.03, row_heights=[0.75, 0.25],
                        subplot_titles=(f"{ticker} - MULTI-HORIZON PROJECTIONS", "VOLUME"))

    # Price candles (up-candle green toned down so it does not glare on a white background).
    fig.add_trace(go.Candlestick(
        x=df_ui['Date'], open=df_ui['Open'], high=df_ui['High'],
        low=df_ui['Low'], close=df_ui['Close'], name='Price',
        increasing_line_color='#00cc00', decreasing_line_color='#ff3333',
    ), row=1, col=1)
    fig.add_trace(go.Scatter(x=df_ui['Date'], y=df_ui['SMA_20'], mode='lines', name='SMA 20',
                             line=dict(color='#ffaa00', width=1)), row=1, col=1)
    fig.add_trace(go.Scatter(x=df_ui['Date'], y=df_ui['SMA_50'], mode='lines', name='SMA 50',
                             line=dict(color='#00bfff', width=1)), row=1, col=1)

    # Volume bars coloured like their candle; vectorised rather than a per-row iterrows() loop.
    volume_colors = np.where(df_ui['Close'] >= df_ui['Open'], '#00cc00', '#ff3333').tolist()
    fig.add_trace(go.Bar(x=df_ui['Date'], y=df_ui['Volume'], marker_color=volume_colors, name='Volume'),
                  row=2, col=1)

    # Projection x-axis: the last close, then each horizon counted in business days (holidays ignored).
    x_future = [base_date] + [base_date + pd.offsets.BDay(h) for h in _HORIZONS]

    trajectories = (
        ("Linear Regression", "LR Trajectory", '#ff00ff'),
        ("SVR", "SVR Trajectory", '#00cccc'),
    )
    for model_key, trace_name, color in trajectories:
        if model_name in (model_key, _BOTH_MODELS):
            y_path = [last_close] + [preds[h][model_key]["pred_price"] for h in _HORIZONS]
            fig.add_trace(go.Scatter(x=x_future, y=y_path, mode='lines+markers', name=trace_name,
                                     line=dict(color=color, dash='dot'),
                                     marker=dict(size=8, symbol='diamond')), row=1, col=1)

    # Risk cone: consensus target +/- ATR scaled by sqrt(horizon).
    spread = {h: atr_val * np.sqrt(h) for h in _HORIZONS}
    upper_band = [last_close] + [target[h] + spread[h] for h in _HORIZONS]
    lower_band = [last_close] + [target[h] - spread[h] for h in _HORIZONS]
    # RGB channels must lie in 0-255. The lower trace must be added directly after the
    # upper one because fill='tonexty' shades down to the previously added trace.
    fig.add_trace(go.Scatter(x=x_future, y=upper_band, mode='lines', name='Risk Cone Upper',
                             line=dict(color='rgba(0, 255, 0, 1)')), row=1, col=1)
    fig.add_trace(go.Scatter(x=x_future, y=lower_band, mode='lines', fill='tonexty',
                             fillcolor='rgba(128, 128, 128, 0.25)', name='Risk Cone Lower',
                             line=dict(color='rgba(255, 0, 0, 0.6)')), row=1, col=1)

    # Neutral greys and transparent backgrounds so the chart reads in both light and dark mode.
    fig.update_layout(
        margin=dict(l=40, r=40, t=40, b=40),
        xaxis_rangeslider_visible=False,
        legend=dict(orientation="h", yanchor="bottom", y=1.02, xanchor="right", x=1, font=dict(color="#888888")),
        font=dict(family="Courier New, monospace", color="#525252"),
        paper_bgcolor='rgba(0,0,0,0)',
        plot_bgcolor='rgba(0,0,0,0)',
    )
    axis_style = dict(showgrid=True, gridwidth=1, gridcolor='rgba(128, 128, 128, 0.2)',
                      linecolor='rgba(128, 128, 128, 0.3)', tickfont=dict(color="#888888"))
    fig.update_xaxes(**axis_style)
    fig.update_yaxes(**axis_style)
    return fig


def generate_quant_dashboard(ticker, model_name):
    """Run inference for `ticker` and build the three dashboard outputs.

    Args:
        ticker: Asset symbol selected in the UI.
        model_name: A model key present in the inference output ("SVR",
            "Linear Regression"), or "Cả Hai" to average both models.

    Returns:
        Tuple ``(consensus_html, context_html, figure)``. On any failure the
        first element is an HTML error banner, the second is an empty string
        and the figure is blank.
    """
    import logging

    try:
        preds, last_close, last_date, _ = predict_horizons(ticker, model_name)
        df_ui = _load_price_history(ticker, last_date)
        atr_val = df_ui['ATR_14'].iloc[-1]
        base_date = pd.to_datetime(last_date)

        target = {h: _model_value(preds, model_name, h, "pred_price") for h in _HORIZONS}
        ret = {h: _model_value(preds, model_name, h, "pred_return") for h in _HORIZONS}

        consensus_html = _build_consensus_html(target, ret)
        context_html = _build_context_html(last_close, last_date, target)
        fig = _build_chart(ticker, model_name, preds, df_ui, last_close, base_date, target, atr_val)
    except Exception as e:  # UI boundary: report any failure in the panel instead of crashing the event
        logging.getLogger(__name__).exception("Dashboard generation failed for %s / %s", ticker, model_name)
        return _render_error_html(e), "", go.Figure()
    return consensus_html, context_html, fig
# --- Gradio UI layout ---
css = """
.gradio-container { max-width: 1400px !important; }
#main-row { margin-top: 20px; }
"""

# Client-side handler: toggles the `dark` class on <body> without a Python round-trip.
toggle_script = """
function toggle_theme() {
    if (document.querySelector('body').classList.contains('dark')) {
        document.querySelector('body').classList.remove('dark');
    } else {
        document.querySelector('body').classList.add('dark');
    }
    return [];
}
"""

with gr.Blocks(title="Quant Terminal | Stock ML", css=css, theme=gr.themes.Monochrome()) as demo:
    # Header: title banner on the left, theme toggle on the right.
    with gr.Row():
        with gr.Column(scale=9):
            gr.Markdown("""
            <div style="padding: 10px 0; border-bottom: 2px solid var(--border-color-primary);">
            <h1 style="margin: 0; font-family: monospace;">⚡ QUANTRONIC ML TERMINAL v2.0</h1>
            <p style="color: var(--body-text-color-subdued); margin: 0; font-family: monospace;">SVR & Ridge Regression Predictive Analytics Engine</p>
            </div>
            """)
        with gr.Column(scale=1, min_width=150):
            btn_theme = gr.Button("🌓 Toggle Dark/Light", size="sm")

    with gr.Row(elem_id="main-row"):
        # Sidebar: model inputs, the run button and the two HTML result panels.
        with gr.Column(scale=1, min_width=300):
            gr.Markdown("<h4 style='font-family: monospace;'>⚙️ PARAMETERS</h4>")
            ticker_dd = gr.Dropdown(choices=["AAPL", "MSFT", "GOOGL", "AMZN"], value="AAPL", label="Asset Ticker")
            model_dd = gr.Dropdown(choices=["SVR", "Linear Regression", "Cả Hai"], value="Cả Hai", label="ML Model Selection")
            btn_predict = gr.Button("EXECUTE INFERENCE", variant="primary", size="lg")
            gr.Markdown("<br><h4 style='font-family: monospace;'>📊 METRICS & RISK</h4>")
            consensus_panel = gr.HTML()
            context_panel = gr.HTML()
        # Main area: the Plotly chart.
        with gr.Column(scale=3):
            plot_chart = gr.Plot()

    btn_theme.click(fn=None, inputs=None, outputs=None, js=toggle_script)

    # Identical inference wiring for the button click and for the initial page load
    # (registered in that order, each with its own fresh input/output lists).
    for register_event in (btn_predict.click, demo.load):
        register_event(
            fn=generate_quant_dashboard,
            inputs=[ticker_dd, model_dd],
            outputs=[consensus_panel, context_panel, plot_chart],
        )

if __name__ == "__main__":
    demo.launch()