| import datetime as dt |
| import pandas as pd |
| import torch |
| import gradio as gr |
| import requests |
| import matplotlib |
| matplotlib.use("Agg") |
| import matplotlib.pyplot as plt |
|
|
| from chronos import BaseChronosPipeline |
|
|
|
|
| |
| |
| |
# (model_id, device) -> loaded Chronos pipeline, kept for the process lifetime.
_PIPELINE_CACHE: dict[tuple[str, str], object] = {}


def get_pipeline(model_id: str, device: str = "cpu"):
    """Return the Chronos pipeline for ``model_id`` on ``device``, loading it once.

    The first call for a given ``(model_id, device)`` pair loads the weights via
    ``BaseChronosPipeline.from_pretrained`` and stores the result in
    ``_PIPELINE_CACHE``; later calls return the cached object unchanged.
    CPU pipelines are loaded in float32, every other device in bfloat16.
    """
    cache_key = (model_id, device)
    try:
        return _PIPELINE_CACHE[cache_key]
    except KeyError:
        pass  # first request for this pair: load it below

    weights_dtype = torch.bfloat16 if device != "cpu" else torch.float32
    pipeline = BaseChronosPipeline.from_pretrained(
        model_id,
        device_map=device,
        torch_dtype=weights_dtype,
    )
    _PIPELINE_CACHE[cache_key] = pipeline
    return pipeline
|
|
|
|
| |
| |
| |
# Supported intervals: UI value -> Binance `interval` query parameter.
# The spellings coincide, so the mapping is the identity.
_BINANCE_INTERVAL = {name: name for name in ("1d", "1h", "30m", "15m", "5m")}
|
|
|
|
| def _yf_to_binance_symbol(ticker: str) -> str | None: |
| """ |
| BTC-USD -> BTCUSDT, ETH-USD -> ETHUSDT ... |
| ๊ทธ ์ธ ํ์์ None (ํ์ฌ๋ -USD ์ฝ์ธ๋ง ์ง์) |
| """ |
| t = ticker.upper().strip() |
| if t.endswith("-USD") and len(t) >= 6: |
| base = t[:-4] |
| return f"{base}USDT" |
| return None |
|
|
|
|
| def _fetch_binance_klines( |
| ticker: str, interval: str, start: str | None, end: str | None |
| ) -> pd.Series: |
| """ |
| Binance Klines (๋ฌด์ธ์ฆ) |
| https://api.binance.com/api/v3/klines |
| |
| ๋ฐํ: pandas.Series(index=datetime, values=float close) |
| """ |
| if interval not in _BINANCE_INTERVAL: |
| raise ValueError("Binance๋ ํด๋น interval์ ์ง์ํ์ง ์์ต๋๋ค.") |
|
|
| symbol = _yf_to_binance_symbol(ticker) |
| if not symbol: |
| raise ValueError( |
| "์ด ํฐ์ปค๋ Binance ์ฌ๋ณผ๋ก ๋ณํํ ์ ์์ต๋๋ค. ์: BTC-USD, ETH-USD ํํ๋ง ์ง์ํฉ๋๋ค." |
| ) |
|
|
| base = "https://api.binance.com/api/v3/klines" |
|
|
| def to_ms(s: str) -> int: |
| return int(pd.to_datetime(s).timestamp() * 1000) |
|
|
| start_ms = to_ms(start) if start else None |
| end_ms = to_ms(end) if end else None |
|
|
| rows = [] |
| cur_start = start_ms |
|
|
| while True: |
| params = { |
| "symbol": symbol, |
| "interval": _BINANCE_INTERVAL[interval], |
| "limit": 1000, |
| } |
| if cur_start is not None: |
| params["startTime"] = cur_start |
| if end_ms is not None: |
| params["endTime"] = end_ms |
|
|
| r = requests.get(base, params=params, timeout=30) |
| r.raise_for_status() |
| data = r.json() |
| if not data: |
| break |
|
|
| rows.extend(data) |
|
|
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| last_close_time = data[-1][6] |
| next_start = last_close_time + 1 |
| if cur_start is not None and next_start <= cur_start: |
| break |
| cur_start = next_start |
|
|
| |
| if len(data) < 1000: |
| break |
|
|
| if not rows: |
| raise ValueError("Binance์์ ๊ฐ์ ธ์จ ๋ฐ์ดํฐ๊ฐ ์์ต๋๋ค.") |
|
|
| df = pd.DataFrame( |
| rows, |
| columns=[ |
| "openTime", |
| "open", |
| "high", |
| "low", |
| "close", |
| "volume", |
| "closeTime", |
| "quoteAssetVolume", |
| "numTrades", |
| "takerBuyBase", |
| "takerBuyQuote", |
| "ignore", |
| ], |
| ) |
| df["ts"] = pd.to_datetime(df["closeTime"], unit="ms") |
| s = df.set_index("ts")["close"].astype(float).sort_index() |
|
|
| if start: |
| s = s[s.index >= pd.to_datetime(start)] |
| if end: |
| s = s[s.index <= pd.to_datetime(end)] |
| if s.empty: |
| raise ValueError("Binance ์๋ฆฌ์ฆ๊ฐ ๋น์ด ์์ต๋๋ค. ๊ธฐ๊ฐ/๊ฐ๊ฒฉ์ ๋ค์ ์ค์ ํด ์ฃผ์ธ์.") |
|
|
| return s |
|
|
|
|
def load_close_series(
    ticker: str, start: str | None, end: str | None, interval: str = "1d"
) -> pd.Series:
    """Load a close-price series from Binance (the only data source).

    Args:
        ticker: Yahoo-style crypto ticker, e.g. ``BTC-USD`` or ``ETH-USD``;
            it is stripped and upper-cased before use.
        start: Start date (``YYYY-MM-DD``). Blank or unparseable values fall
            back to ``2017-01-01``.
        end: End date (``YYYY-MM-DD``). Blank or unparseable values fall back
            to today's date.
        interval: Candle interval: ``1d``, ``1h``, ``30m``, ``15m`` or ``5m``.

    Returns:
        ``pandas.Series`` (datetime index, float close values) as produced by
        ``_fetch_binance_klines``.

    Any time-of-day part of the dates is dropped, and if the end date precedes
    the start date the two are swapped.
    """
    ticker = ticker.strip().upper()

    def _parse_day(value: str | None, fallback: dt.date) -> dt.date:
        # Best-effort: a blank or invalid entry replaces only *that* bound
        # with its default; the other, valid bound is kept.
        if value is None or (isinstance(value, str) and not value.strip()):
            return fallback
        try:
            ts = pd.to_datetime(value)
        except (ValueError, TypeError, OverflowError):
            return fallback
        return fallback if pd.isna(ts) else ts.date()

    start_day = _parse_day(start, dt.date(2017, 1, 1))
    end_day = _parse_day(end, dt.date.today())
    if end_day < start_day:
        start_day, end_day = end_day, start_day

    return _fetch_binance_klines(
        ticker, interval, start_day.isoformat(), end_day.isoformat()
    )
|
|
|
|
| |
| |
| |
def run_forecast(ticker, start_date, end_date, horizon, model_id, device, interval):
    """Gradio callback: fetch prices, run the Chronos pipeline, build outputs.

    Args:
        ticker: Yahoo-style crypto ticker (e.g. ``BTC-USD``).
        start_date: Start date string forwarded to ``load_close_series``.
        end_date: End date string forwarded to ``load_close_series``.
        horizon: Number of future steps; one step equals one ``interval``.
        model_id: Model id passed to ``get_pipeline``.
        device: Device passed to ``get_pipeline``.
        interval: Candle interval (``1d``, ``1h``, ``30m``, ``15m``, ``5m``).

    Returns:
        ``(figure, forecast_table, message)``. On failure the figure is
        ``None``, the table is empty and the message describes the error.
        On success the table has float columns ``q10``/``q50``/``q90``
        indexed by ``step`` = 1..H, and the message is a disclaimer.
    """
    # Validate the horizon first: it is cheap and avoids a network fetch and
    # a model load for input that can never succeed.
    try:
        H = int(horizon)
    except (TypeError, ValueError, OverflowError):
        return None, pd.DataFrame(), "예측 스텝 H가 올바르지 않습니다."
    if H <= 0:
        return None, pd.DataFrame(), "예측 스텝 H는 1 이상이어야 합니다."

    try:
        series = load_close_series(ticker, start_date, end_date, interval)
    except Exception as e:  # UI boundary: report any fetch/parse failure
        return None, pd.DataFrame(), f"데이터 로딩 오류 (Binance): {e}"

    try:
        pipe = get_pipeline(model_id, device)
    except Exception as e:
        return None, pd.DataFrame(), f"모델 로딩 오류: {e}"

    context = torch.tensor(series.values, dtype=torch.float32)

    try:
        # predict() returns every quantile the model was trained on (nine for
        # Chronos-Bolt: 0.1, 0.2, ..., 0.9), so positional indexing [0..2]
        # would yield q0.1/q0.2/q0.3. Request exactly the levels we plot.
        quantiles, _mean = pipe.predict_quantiles(
            context, prediction_length=H, quantile_levels=[0.1, 0.5, 0.9]
        )
    except Exception as e:
        return None, pd.DataFrame(), f"예측 실행 오류: {e}"

    # quantiles: (batch, H, 3); take the single series in the batch. Cast to
    # float32 so .numpy() also works if the pipeline ran in bfloat16.
    q = quantiles[0].detach().to(torch.float32).cpu().numpy()
    q10, q50, q90 = q[:, 0], q[:, 1], q[:, 2]

    df_fcst = pd.DataFrame(
        {"q10": q10, "q50": q50, "q90": q90},
        index=pd.RangeIndex(1, H + 1, name="step"),
    )

    # Pandas frequency per Binance interval ("h"/"min" instead of the
    # deprecated "H"/"T" aliases).
    freq_map = {"1d": "D", "1h": "h", "30m": "30min", "15m": "15min", "5m": "5min"}
    freq = freq_map.get(interval, "D")
    future_index = pd.date_range(series.index[-1], periods=H + 1, freq=freq)[1:]

    fig, ax = plt.subplots(figsize=(10, 4))
    ax.plot(series.index, series.values, label="history")
    ax.plot(future_index, q50, label="forecast(q50)")
    ax.fill_between(
        future_index,
        q10,
        q90,
        alpha=0.2,
        label="q10–q90",
    )
    ax.set_title(f"{ticker} forecast by Chronos-Bolt (Binance, {interval}, H={H})")
    ax.legend()
    fig.tight_layout()
    # Unregister from pyplot so repeated requests don't accumulate open
    # figures; the returned Figure object is still renderable.
    plt.close(fig)

    note = "※ 데모 목적입니다. 투자 판단과 결과 책임은 전적으로 본인에게 있습니다."
    return fig, df_fcst, note
|
|
|
|
| |
| |
| |
| with gr.Blocks(title="Chronos Crypto Forecast (Binance)") as demo: |
| gr.Markdown("# Chronos ํฌ๋ฆฝํ ์์ธก ๋ฐ๋ชจ (Binance ์ ์ฉ)") |
| gr.Markdown( |
| "ํฐ์ปค๋ `BTC-USD`, `ETH-USD` ์ฒ๋ผ ์
๋ ฅํ๋ฉด ๋ด๋ถ์์ `BTCUSDT`, `ETHUSDT`๋ก ๋ณํํด์ Binance์์ ๊ฐ๊ฒฉ์ ๊ฐ์ ธ์ต๋๋ค." |
| ) |
|
|
| with gr.Row(): |
| ticker = gr.Textbox( |
| value="BTC-USD", |
| label="ํฐ์ปค (์: BTC-USD, ETH-USD)", |
| ) |
| horizon = gr.Slider( |
| 5, |
| 365, |
| value=90, |
| step=1, |
| label="์์ธก ์คํ
H (๊ฐ๊ฒฉ ๋จ์์ ๋์ผ)", |
| ) |
|
|
| with gr.Row(): |
| start = gr.Textbox( |
| value="2017-01-01", |
| label="์์์ผ (YYYY-MM-DD)", |
| ) |
| end = gr.Textbox( |
| value=dt.date.today().isoformat(), |
| label="์ข
๋ฃ์ผ (YYYY-MM-DD)", |
| ) |
|
|
| with gr.Row(): |
| model_id = gr.Dropdown( |
| choices=[ |
| "amazon/chronos-bolt-tiny", |
| "amazon/chronos-bolt-mini", |
| "amazon/chronos-bolt-small", |
| "amazon/chronos-bolt-base", |
| ], |
| value="amazon/chronos-bolt-small", |
| label="๋ชจ๋ธ", |
| ) |
| device = gr.Dropdown( |
| choices=["cpu"], |
| value="cpu", |
| label="๋๋ฐ์ด์ค", |
| ) |
| interval = gr.Dropdown( |
| choices=["1d", "1h", "30m", "15m", "5m"], |
| value="1d", |
| label="๊ฐ๊ฒฉ (Binance interval)", |
| ) |
|
|
| btn = gr.Button("์์ธก ์คํ") |
|
|
| plot = gr.Plot(label="History + Forecast") |
| table = gr.Dataframe(label="์์ธก ๊ฒฐ๊ณผ (๋ถ์์)") |
| note = gr.Markdown() |
|
|
| btn.click( |
| fn=run_forecast, |
| inputs=[ticker, start, end, horizon, model_id, device, interval], |
| outputs=[plot, table, note], |
| ) |
|
|
| if __name__ == "__main__": |
| demo.launch() |
|
|