Reality8081's picture
Update src
22259b4
import gradio as gr
import plotly.graph_objects as go
from plotly.subplots import make_subplots
import pandas as pd
import numpy as np
import yfinance as yf
from datetime import datetime, timedelta
from src.inference import predict_horizons
# --- CÁC HÀM HỖ TRỢ TÍNH TOÁN KỸ THUẬT CHO UI ---
def calculate_ui_technical_indicators(df):
    """Attach the indicators used by the UI / Market Context panel to ``df``.

    Reads the ``Close``, ``High`` and ``Low`` columns and adds, in this order,
    ``SMA_20``, ``SMA_50``, ``RSI_14``, ``MACD``, ``MACD_Signal``,
    ``MACD_Hist`` and ``ATR_14``. The frame is modified in place and the same
    object is returned.
    """
    close = df['Close']

    # Simple moving averages over 20 and 50 rows.
    for window in (20, 50):
        df[f'SMA_{window}'] = close.rolling(window=window).mean()

    # RSI 14: ratio of the rolling mean gain to the rolling mean loss.
    # `where(..., 0)` also turns the leading NaN of diff() into 0.
    change = close.diff()
    avg_gain = change.where(change > 0, 0).rolling(window=14).mean()
    avg_loss = -(change.where(change < 0, 0).rolling(window=14).mean())
    relative_strength = avg_gain / avg_loss
    df['RSI_14'] = 100 - (100 / (1 + relative_strength))

    # MACD: EMA(12) - EMA(26), a 9-span EMA signal line, and their difference.
    def _ema(series, span):
        return series.ewm(span=span, adjust=False).mean()

    macd_line = _ema(close, 12) - _ema(close, 26)
    signal_line = _ema(macd_line, 9)
    df['MACD'] = macd_line
    df['MACD_Signal'] = signal_line
    df['MACD_Hist'] = macd_line - signal_line

    # ATR 14: rolling mean of the true range (largest of the three spreads).
    previous_close = close.shift()
    true_range = pd.concat(
        [
            df['High'] - df['Low'],
            (df['High'] - previous_close).abs(),
            (df['Low'] - previous_close).abs(),
        ],
        axis=1,
    ).max(axis=1)
    df['ATR_14'] = true_range.rolling(window=14).mean()

    return df
# --- XỬ LÝ GIAO DIỆN & LOGIC DỰ BÁO ---
def _classify_signal(ret_7d, ret_21d):
    """Return ``(label, text_color, background_color)`` for the consensus badge.

    The sign of the 21-day expected return picks the side; the "STRONG"
    variants additionally need a 21-day move beyond 2% and a 7-day return
    pointing the same way.
    """
    if ret_21d > 0.02 and ret_7d > 0:
        return "STRONG BUY", "#00ff00", "rgba(0, 255, 0, 0.1)"
    if ret_21d > 0:
        return "ACCUMULATE", "#00cc66", "rgba(0, 204, 102, 0.1)"
    if ret_21d < -0.02 and ret_7d < 0:
        return "STRONG SELL", "#ff0000", "rgba(255, 0, 0, 0.1)"
    return "REDUCE / SELL", "#ff6666", "rgba(255, 102, 102, 0.1)"


def _format_return_html(val):
    """Render a fractional return as a signed, colour-coded percentage ``<span>``."""
    # The green is slightly darkened so it stays readable on a white background.
    color = "#00cc00" if val > 0 else "#ff4444"
    sign = "+" if val > 0 else ""
    return f"<span style='color: {color}; font-weight: bold;'>{sign}{val*100:.2f}%</span>"


def _render_consensus_panel(targets, returns):
    """Build the "AI CONSENSUS SIGNAL" HTML card.

    ``targets`` and ``returns`` are 3-item sequences ordered T+1, T+7, T+21.
    """
    target_1d, target_7d, target_21d = targets
    ret_1d, ret_7d, ret_21d = returns
    signal_text, signal_color, bg_color = _classify_signal(ret_7d, ret_21d)
    return f"""
<div style="background: var(--background-fill-secondary); border: 1px solid var(--border-color-primary); padding: 15px; border-radius: 5px; font-family: monospace; margin-bottom: 15px;">
<div style="display: flex; justify-content: space-between; align-items: center; border-bottom: 1px solid var(--border-color-primary); padding-bottom: 10px; margin-bottom: 15px;">
<h3 style="color: var(--body-text-color-subdued); margin:0; font-size:14px;">AI CONSENSUS SIGNAL</h3>
<span style="background:{bg_color}; color:{signal_color}; padding: 4px 10px; border-radius: 3px; border: 1px solid {signal_color}; font-weight: bold; font-size: 14px;">
{signal_text}
</span>
</div>
<table style="width: 100%; color: var(--body-text-color); font-size: 13px; text-align: left; border-collapse: collapse;">
<tr style="color: var(--body-text-color-subdued);">
<th style="padding-bottom: 8px;">Horizon</th>
<th style="text-align: right; padding-bottom: 8px;">Target Price</th>
<th style="text-align: right; padding-bottom: 8px;">Exp. Return</th>
</tr>
<tr style="border-bottom: 1px dashed var(--border-color-primary);">
<td style="padding: 8px 0;">T+1 (Day)</td>
<td style="text-align: right; color:#d4af37; font-weight: bold;">${target_1d:.2f}</td>
<td style="text-align: right;">{_format_return_html(ret_1d)}</td>
</tr>
<tr style="border-bottom: 1px dashed var(--border-color-primary);">
<td style="padding: 8px 0;">T+7 (Week)</td>
<td style="text-align: right; color:#d48b00; font-weight: bold;">${target_7d:.2f}</td>
<td style="text-align: right;">{_format_return_html(ret_7d)}</td>
</tr>
<tr>
<td style="padding: 8px 0;">T+21 (Month)</td>
<td style="text-align: right; color:#cc4400; font-weight: bold;">${target_21d:.2f}</td>
<td style="text-align: right;">{_format_return_html(ret_21d)}</td>
</tr>
</table>
</div>
"""


def _render_context_panel(last_date, last_close, targets):
    """Build the market-context HTML card (last close plus the three targets)."""
    target_1d, target_7d, target_21d = targets
    return f"""
<div style="background: var(--background-fill-secondary); border: 1px solid var(--border-color-primary); padding: 15px; border-radius: 5px; font-family: monospace;">
<p style="color: var(--body-text-color-subdued); margin:0 0 10px 0; font-size:12px;">LAST CLOSE: {last_date}</p>
<h2 style="color: var(--body-text-color); margin:0 0 15px 0; border-bottom: 1px solid var(--border-color-primary); padding-bottom: 10px;">${last_close:.2f}</h2>
<table style="width: 100%; color: var(--body-text-color-subdued); font-size: 13px;">
<tr><td style="padding: 4px 0;">Target T+1 (Day)</td><td style="text-align: right; font-weight: bold; color: #d4af37;">${target_1d:.2f}</td></tr>
<tr><td style="padding: 4px 0;">Target T+7 (Week)</td><td style="text-align: right; font-weight: bold; color: #d48b00;">${target_7d:.2f}</td></tr>
<tr><td style="padding: 4px 0;">Target T+21 (Month)</td><td style="text-align: right; font-weight: bold; color: #cc4400;">${target_21d:.2f}</td></tr>
</table>
</div>
"""


def _build_projection_figure(ticker, df_ui, x_future, trajectories, upper_band, lower_band):
    """Draw candlesticks, SMAs, volume, model trajectories and the risk cone.

    ``trajectories`` maps a model name to its y-values aligned with
    ``x_future``; only "Linear Regression" and "SVR" entries are plotted.
    """
    fig = make_subplots(rows=2, cols=1, shared_xaxes=True,
                        vertical_spacing=0.03, row_heights=[0.75, 0.25],
                        subplot_titles=(f"{ticker} - MULTI-HORIZON PROJECTIONS", "VOLUME"))

    # Candlestick (the "up" green is toned down so it does not glare on white).
    fig.add_trace(go.Candlestick(
        x=df_ui['Date'], open=df_ui['Open'], high=df_ui['High'],
        low=df_ui['Low'], close=df_ui['Close'], name='Price',
        increasing_line_color='#00cc00', decreasing_line_color='#ff3333'
    ), row=1, col=1)

    # SMAs
    for column, label, color in (("SMA_20", "SMA 20", "#ffaa00"), ("SMA_50", "SMA 50", "#00bfff")):
        fig.add_trace(go.Scatter(x=df_ui['Date'], y=df_ui[column], mode='lines', name=label,
                                 line=dict(color=color, width=1)), row=1, col=1)

    # Volume subplot: green when the bar closed at or above its open, red otherwise.
    volume_colors = np.where(df_ui['Close'] >= df_ui['Open'], '#00cc00', '#ff3333').tolist()
    fig.add_trace(go.Bar(x=df_ui['Date'], y=df_ui['Volume'], marker_color=volume_colors, name='Volume'), row=2, col=1)

    # Forecast trajectories, one dotted line per selected model.
    trajectory_styles = {
        "Linear Regression": ("LR Trajectory", "#ff00ff"),
        "SVR": ("SVR Trajectory", "#00cccc"),
    }
    for model, (label, color) in trajectory_styles.items():
        if model not in trajectories:
            continue
        fig.add_trace(go.Scatter(x=x_future, y=trajectories[model], mode='lines+markers', name=label,
                                 line=dict(color=color, dash='dot'), marker=dict(size=8, symbol='diamond')), row=1, col=1)

    # Error band: neutral translucent grey fill so it works on light and dark
    # backgrounds. Colour channels are capped at 255, the maximum valid value.
    fig.add_trace(go.Scatter(x=x_future, y=upper_band, mode='lines', name='Risk Cone Upper',
                             line=dict(color='rgba(0, 255, 0, 1)')), row=1, col=1)
    fig.add_trace(go.Scatter(x=x_future, y=lower_band, mode='lines', fill='tonexty',
                             fillcolor='rgba(128, 128, 128, 0.25)', name='Risk Cone Lower',
                             line=dict(color='rgba(255, 0, 0, 0.6)')), row=1, col=1)

    # Layout tuned for both dark and light mode: transparent backgrounds, neutral grey text.
    fig.update_layout(
        margin=dict(l=40, r=40, t=40, b=40),
        xaxis_rangeslider_visible=False,
        legend=dict(orientation="h", yanchor="bottom", y=1.02, xanchor="right", x=1, font=dict(color="#888888")),
        font=dict(family="Courier New, monospace", color="#525252"),
        paper_bgcolor='rgba(0,0,0,0)',
        plot_bgcolor='rgba(0,0,0,0)'
    )
    # Neutral, low-contrast grid lines.
    fig.update_xaxes(showgrid=True, gridwidth=1, gridcolor='rgba(128, 128, 128, 0.2)', linecolor='rgba(128, 128, 128, 0.3)', tickfont=dict(color="#888888"))
    fig.update_yaxes(showgrid=True, gridwidth=1, gridcolor='rgba(128, 128, 128, 0.2)', linecolor='rgba(128, 128, 128, 0.3)', tickfont=dict(color="#888888"))
    return fig


def generate_quant_dashboard(ticker, model_name):
    """Run inference for ``ticker`` and build the three dashboard outputs.

    Args:
        ticker: Symbol passed to ``predict_horizons`` and ``yf.download``.
        model_name: "SVR", "Linear Regression", or "Cả Hai" (both). With
            "Cả Hai" the targets and returns are the mean of the two models.

    Returns:
        ``(consensus_html, context_html, figure)``. If inference, the price
        download, or reading the predictions fails, returns
        ``(error_html, "", empty Figure)`` instead of raising.

    ``predict_horizons`` is expected to return ``(preds, last_close,
    last_date, _)`` where ``preds[h][model]`` holds ``"pred_price"`` and
    ``"pred_return"`` for h in 1, 7, 21 and ``last_date`` is a
    ``YYYY-MM-DD`` string; that is how the values are used below.
    """
    from html import escape  # local import: only needed for the error panel

    horizons = (1, 7, 21)
    models = ["Linear Regression", "SVR"] if model_name == "Cả Hai" else [model_name]

    try:
        # 1. Call the inference engine.
        preds, last_close, last_date, _ = predict_horizons(ticker, model_name)

        # 2. Fetch 90 calendar days of OHLCV for the candlestick and indicators.
        #    One day is added to the end date so the last close itself is included.
        end_dt = datetime.strptime(last_date, '%Y-%m-%d') + timedelta(days=1)
        start_dt = end_dt - timedelta(days=90)
        df_ui = yf.download(ticker, start=start_dt.strftime('%Y-%m-%d'), end=end_dt.strftime('%Y-%m-%d'), progress=False)
        if isinstance(df_ui.columns, pd.MultiIndex):
            df_ui.columns = df_ui.columns.droplevel(1)
        if df_ui.empty:
            # Fail with a readable message rather than an opaque indexing error.
            raise ValueError(f"No price data returned for {ticker}")
        df_ui = df_ui.reset_index()
        df_ui = calculate_ui_technical_indicators(df_ui)
        atr_val = df_ui['ATR_14'].iloc[-1]

        base_date = pd.to_datetime(last_date)
        x_future = [base_date] + [base_date + pd.offsets.BDay(h) for h in horizons]

        # 3. Read every prediction here so a missing horizon or model is
        #    reported in the error panel instead of crashing the callback.
        trajectories = {
            m: [last_close] + [preds[h][m]["pred_price"] for h in horizons]
            for m in models
        }
        targets = [sum(preds[h][m]["pred_price"] for m in models) / len(models) for h in horizons]
        returns = [sum(preds[h][m]["pred_return"] for m in models) / len(models) for h in horizons]
    except Exception as e:
        # The exception text is escaped because it is injected into raw HTML.
        error_html = f"""<div style='background-color:#3a1010; padding:15px; border-left: 4px solid #ff4d4d; color: #ff8080;'>
<strong>[SYSTEM ERROR]</strong> Xảy ra lỗi trong quá trình inference: {escape(str(e))}</div>"""
        return error_html, "", go.Figure()

    # 4. HTML panels.
    consensus_html = _render_consensus_panel(targets, returns)
    context_html = _render_context_panel(last_date, last_close, targets)

    # 5. Risk cone: averaged target +/- ATR scaled by sqrt(horizon), anchored at the last close.
    spreads = [atr_val * np.sqrt(h) for h in horizons]
    upper_band = [last_close] + [target + spread for target, spread in zip(targets, spreads)]
    lower_band = [last_close] + [target - spread for target, spread in zip(targets, spreads)]

    fig = _build_projection_figure(ticker, df_ui, x_future, trajectories, upper_band, lower_band)
    return consensus_html, context_html, fig
# --- KIẾN TRÚC GRADIO GIAO DIỆN ---
# Page-level CSS: cap the container width and space the main row from the header.
css = """
.gradio-container { max-width: 1400px !important; }
#main-row { margin-top: 20px; }
"""

# Browser-side handler for the theme button: flips the `dark` class on <body>.
toggle_script = """
function toggle_theme() {
if (document.querySelector('body').classList.contains('dark')) {
document.querySelector('body').classList.remove('dark');
} else {
document.querySelector('body').classList.add('dark');
}
return [];
}
"""

with gr.Blocks(title="Quant Terminal | Stock ML", css=css, theme=gr.themes.Monochrome()) as demo:
    # Header row: title banner on the left, theme toggle on the right.
    with gr.Row():
        with gr.Column(scale=9):
            gr.Markdown("""
<div style="padding: 10px 0; border-bottom: 2px solid var(--border-color-primary);">
<h1 style="margin: 0; font-family: monospace;">⚡ QUANTRONIC ML TERMINAL v2.0</h1>
<p style="color: var(--body-text-color-subdued); margin: 0; font-family: monospace;">SVR & Ridge Regression Predictive Analytics Engine</p>
</div>
""")
        with gr.Column(scale=1, min_width=150):
            btn_theme = gr.Button("🌓 Toggle Dark/Light", size="sm")

    with gr.Row(elem_id="main-row"):
        # Sidebar: inputs, run button and the two HTML metric panels.
        with gr.Column(scale=1, min_width=300):
            gr.Markdown("<h4 style='font-family: monospace;'>⚙️ PARAMETERS</h4>")
            ticker_dd = gr.Dropdown(
                choices=["AAPL", "MSFT", "GOOGL", "AMZN"],
                value="AAPL",
                label="Asset Ticker",
            )
            model_dd = gr.Dropdown(
                choices=["SVR", "Linear Regression", "Cả Hai"],
                value="Cả Hai",
                label="ML Model Selection",
            )
            btn_predict = gr.Button("EXECUTE INFERENCE", variant="primary", size="lg")
            gr.Markdown("<br><h4 style='font-family: monospace;'>📊 METRICS & RISK</h4>")
            consensus_panel = gr.HTML()
            context_panel = gr.HTML()

        # Main area: the Plotly chart.
        with gr.Column(scale=3):
            plot_chart = gr.Plot()

    # The theme toggle runs purely in the browser (no Python callback).
    btn_theme.click(fn=None, inputs=None, outputs=None, js=toggle_script)

    # The dashboard is rendered both on button click and on initial page load.
    for _bind in (btn_predict.click, demo.load):
        _bind(
            fn=generate_quant_dashboard,
            inputs=[ticker_dd, model_dd],
            outputs=[consensus_panel, context_panel, plot_chart],
        )

if __name__ == "__main__":
    demo.launch()