delta0790 commited on
Commit
9a681d9
·
verified ·
1 Parent(s): ba6ef65

Add training script

Browse files
Files changed (1) hide show
  1. train_sac_crypto.py +537 -0
train_sac_crypto.py ADDED
@@ -0,0 +1,537 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ SAC Crypto Trading Agent - Training Script
3
+ Based on FinRL-Meta (arXiv:2304.13174) recipe:
4
+ - Dataset: linxy/CryptoCoin (Binance OHLCV) on HF Hub
5
+ - SAC hyperparams: lr=3e-4, batch=64, net_arch=[64,32], ent_coef=auto
6
+ - Technical indicators: MACD, RSI(30), CCI(30), DX(30), SMA(30), Bollinger Bands
7
+ - Reward: ΔPortfolioValue * scaling
8
+ - Commission: 0.1% (Binance spot)
9
+
10
+ Usage:
11
+ pip install stable-baselines3 gymnasium huggingface_hub pandas numpy tensorboard
12
+
13
+ python train_sac_crypto.py \
14
+ --symbol BTCUSDT \
15
+ --timeframe 1d \
16
+ --timesteps 200000 \
17
+ --lr 3e-4 \
18
+ --batch_size 64 \
19
+ --buffer_size 100000 \
20
+ --gamma 0.99 \
21
+ --tau 0.005 \
22
+ --net_arch 64 32 \
23
+ --initial_amount 100000 \
24
+ --commission 0.001 \
25
+ --max_btc 10.0 \
26
+ --reward_scaling 1e-4 \
27
+ --seed 42 \
28
+ --save_dir ./sac_crypto_model \
29
+ --push_to_hub \
30
+ --hub_model_id YOUR_USERNAME/sac-crypto-btc-agent
31
+ """
32
+
33
+ import os
34
+ import json
35
+ import numpy as np
36
+ import pandas as pd
37
+ from io import StringIO
38
+ from datetime import datetime
39
+
40
+ # ============================================================
41
+ # 1. DATA LOADING & FEATURE ENGINEERING
42
+ # ============================================================
43
+
44
+ def load_crypto_data_from_hf(symbol="BTCUSDT", timeframe="1d"):
45
+ """Load crypto OHLCV data from HF Hub dataset linxy/CryptoCoin."""
46
+ from huggingface_hub import hf_hub_download
47
+
48
+ filename = f"{symbol}_{timeframe}.csv"
49
+ print(f"Downloading {filename} from linxy/CryptoCoin...")
50
+
51
+ path = hf_hub_download(
52
+ repo_id="linxy/CryptoCoin",
53
+ filename=filename,
54
+ repo_type="dataset",
55
+ )
56
+
57
+ df = pd.read_csv(path)
58
+
59
+ # Standardize column names
60
+ col_map = {
61
+ 'Open time': 'date',
62
+ 'open': 'open',
63
+ 'high': 'high',
64
+ 'low': 'low',
65
+ 'close': 'close',
66
+ 'volume': 'volume',
67
+ }
68
+ df = df.rename(columns=col_map)
69
+
70
+ # Keep only needed columns
71
+ keep = ['date', 'open', 'high', 'low', 'close', 'volume']
72
+ df = df[[c for c in keep if c in df.columns]]
73
+
74
+ df['date'] = pd.to_datetime(df['date'])
75
+ df = df.sort_values('date').reset_index(drop=True)
76
+
77
+ # Drop NaN rows
78
+ df = df.dropna().reset_index(drop=True)
79
+
80
+ print(f"Loaded {len(df)} rows for {symbol} ({timeframe})")
81
+ print(f" Date range: {df['date'].iloc[0]} to {df['date'].iloc[-1]}")
82
+ print(f" Price range: ${df['close'].min():.2f} - ${df['close'].max():.2f}")
83
+
84
+ return df
85
+
86
+
87
+ def add_technical_indicators(df):
88
+ """
89
+ Add technical indicators following FinRL-Meta recipe:
90
+ MACD, RSI(30), CCI(30), DX(30), SMA(30), Bollinger Bands
91
+
92
+ Using pandas/numpy directly to avoid stockstats dependency issues.
93
+ """
94
+ df = df.copy()
95
+ close = df['close']
96
+ high = df['high']
97
+ low = df['low']
98
+
99
+ # --- MACD ---
100
+ ema12 = close.ewm(span=12, adjust=False).mean()
101
+ ema26 = close.ewm(span=26, adjust=False).mean()
102
+ df['macd'] = ema12 - ema26
103
+ df['macd_signal'] = df['macd'].ewm(span=9, adjust=False).mean()
104
+ df['macd_hist'] = df['macd'] - df['macd_signal']
105
+
106
+ # --- RSI (14-period, normalized to [-1, 1]) ---
107
+ delta = close.diff()
108
+ gain = delta.where(delta > 0, 0.0)
109
+ loss = -delta.where(delta < 0, 0.0)
110
+ avg_gain = gain.rolling(window=14, min_periods=1).mean()
111
+ avg_loss = loss.rolling(window=14, min_periods=1).mean()
112
+ rs = avg_gain / (avg_loss + 1e-10)
113
+ rsi = 100 - (100 / (1 + rs))
114
+ df['rsi_30'] = (rsi - 50) / 50 # Normalize to [-1, 1]
115
+
116
+ # --- CCI (20-period) ---
117
+ typical_price = (high + low + close) / 3
118
+ sma_tp = typical_price.rolling(window=20, min_periods=1).mean()
119
+ mad = typical_price.rolling(window=20, min_periods=1).apply(
120
+ lambda x: np.abs(x - x.mean()).mean(), raw=True
121
+ )
122
+ df['cci_30'] = (typical_price - sma_tp) / (0.015 * mad + 1e-10)
123
+ df['cci_30'] = df['cci_30'] / 200 # Normalize
124
+
125
+ # --- DX (Directional Index, 14-period) ---
126
+ plus_dm = high.diff()
127
+ minus_dm = -low.diff()
128
+ plus_dm = plus_dm.where((plus_dm > minus_dm) & (plus_dm > 0), 0.0)
129
+ minus_dm = minus_dm.where((minus_dm > plus_dm) & (minus_dm > 0), 0.0)
130
+ tr = pd.concat([
131
+ high - low,
132
+ (high - close.shift(1)).abs(),
133
+ (low - close.shift(1)).abs()
134
+ ], axis=1).max(axis=1)
135
+ atr = tr.rolling(window=14, min_periods=1).mean()
136
+ plus_di = 100 * plus_dm.rolling(14, min_periods=1).mean() / (atr + 1e-10)
137
+ minus_di = 100 * minus_dm.rolling(14, min_periods=1).mean() / (atr + 1e-10)
138
+ dx = 100 * (plus_di - minus_di).abs() / (plus_di + minus_di + 1e-10)
139
+ df['dx_30'] = dx / 100 # Normalize to [0, 1]
140
+
141
+ # --- SMA (30-day) ratio ---
142
+ sma30 = close.rolling(window=30, min_periods=1).mean()
143
+ df['close_30_sma'] = (close - sma30) / (sma30 + 1e-10)
144
+
145
+ # --- Bollinger Bands (20-period, 2 std) ---
146
+ sma20 = close.rolling(window=20, min_periods=1).mean()
147
+ std20 = close.rolling(window=20, min_periods=1).std()
148
+ df['boll_ub'] = (close - (sma20 + 2 * std20)) / (close + 1e-10)
149
+ df['boll_lb'] = (close - (sma20 - 2 * std20)) / (close + 1e-10)
150
+
151
+ # --- Volume change ratio ---
152
+ df['volume_change'] = df['volume'].pct_change().fillna(0).clip(-5, 5)
153
+
154
+ # Fill NaN from rolling windows
155
+ df = df.fillna(0)
156
+
157
+ print(f"Added {len([c for c in df.columns if c not in ['date','open','high','low','close','volume']])} technical indicators")
158
+
159
+ return df
160
+
161
+
162
+ def prepare_data(symbol="BTCUSDT", timeframe="1d", train_ratio=0.7, val_ratio=0.15):
163
+ """Load data, add indicators, and split into train/val/test."""
164
+ df = load_crypto_data_from_hf(symbol, timeframe)
165
+ df = add_technical_indicators(df)
166
+
167
+ n = len(df)
168
+ train_end = int(n * train_ratio)
169
+ val_end = int(n * (train_ratio + val_ratio))
170
+
171
+ df_train = df.iloc[:train_end].reset_index(drop=True)
172
+ df_val = df.iloc[train_end:val_end].reset_index(drop=True)
173
+ df_test = df.iloc[val_end:].reset_index(drop=True)
174
+
175
+ print(f"\nData splits:")
176
+ print(f" Train: {len(df_train)} days ({df.iloc[0]['date'].date()} to {df.iloc[train_end-1]['date'].date()})")
177
+ print(f" Val: {len(df_val)} days ({df.iloc[train_end]['date'].date()} to {df.iloc[val_end-1]['date'].date()})")
178
+ print(f" Test: {len(df_test)} days ({df.iloc[val_end]['date'].date()} to {df.iloc[-1]['date'].date()})")
179
+
180
+ return df_train, df_val, df_test
181
+
182
+
183
+ # ============================================================
184
+ # 2. TRAINING
185
+ # ============================================================
186
+
187
+ def train_sac_agent(
188
+ df_train,
189
+ df_val,
190
+ total_timesteps=200_000,
191
+ learning_rate=3e-4,
192
+ batch_size=64,
193
+ buffer_size=100_000,
194
+ gamma=0.99,
195
+ tau=0.005,
196
+ net_arch=(64, 32),
197
+ initial_amount=100_000.0,
198
+ commission=0.001,
199
+ max_btc=10.0,
200
+ reward_scaling=1e-4,
201
+ seed=42,
202
+ save_dir="./sac_crypto_model",
203
+ ):
204
+ """Train SAC agent on crypto trading environment."""
205
+ from stable_baselines3 import SAC
206
+ from stable_baselines3.common.env_checker import check_env
207
+ from stable_baselines3.common.vec_env import DummyVecEnv, VecNormalize
208
+ from stable_baselines3.common.callbacks import EvalCallback, BaseCallback
209
+ from crypto_trading_env import SingleAssetTradingEnv
210
+
211
+ print("\n" + "="*60)
212
+ print("TRAINING SAC CRYPTO AGENT")
213
+ print("="*60)
214
+ print(f" Timesteps: {total_timesteps:,}")
215
+ print(f" LR: {learning_rate}, Batch: {batch_size}")
216
+ print(f" Net arch: {list(net_arch)}")
217
+ print(f" Buffer: {buffer_size:,}, Gamma: {gamma}, Tau: {tau}")
218
+ print(f" Initial amount: ${initial_amount:,.0f}")
219
+ print(f" Commission: {commission*100:.1f}%")
220
+ print("="*60)
221
+
222
+ # Create environments
223
+ tech_cols = ['macd', 'macd_hist', 'rsi_30', 'cci_30', 'dx_30',
224
+ 'close_30_sma', 'boll_ub', 'boll_lb', 'volume_change']
225
+
226
+ def make_train_env():
227
+ return SingleAssetTradingEnv(
228
+ df=df_train,
229
+ initial_amount=initial_amount,
230
+ commission_rate=commission,
231
+ reward_scaling=reward_scaling,
232
+ max_btc=max_btc,
233
+ )
234
+
235
+ def make_val_env():
236
+ return SingleAssetTradingEnv(
237
+ df=df_val,
238
+ initial_amount=initial_amount,
239
+ commission_rate=commission,
240
+ reward_scaling=reward_scaling,
241
+ max_btc=max_btc,
242
+ )
243
+
244
+ # Verify environment
245
+ test_env = make_train_env()
246
+ check_env(test_env, warn=True)
247
+ print("✓ Environment passed check_env validation")
248
+ del test_env
249
+
250
+ # Vectorized environments
251
+ train_env = DummyVecEnv([make_train_env])
252
+ val_env = DummyVecEnv([make_val_env])
253
+
254
+ # Normalize observations (not reward - we handle reward scaling ourselves)
255
+ train_env = VecNormalize(train_env, norm_obs=True, norm_reward=False,
256
+ clip_obs=10.0, gamma=gamma)
257
+ val_env = VecNormalize(val_env, norm_obs=True, norm_reward=False,
258
+ clip_obs=10.0, training=False, gamma=gamma)
259
+
260
+ # Custom callback for logging
261
+ class TradingCallback(BaseCallback):
262
+ def __init__(self, verbose=0):
263
+ super().__init__(verbose)
264
+ self.episode_returns = []
265
+
266
+ def _on_step(self) -> bool:
267
+ # Log every 10000 steps
268
+ if self.n_calls % 10000 == 0:
269
+ # Get infos from the environment
270
+ if hasattr(self.training_env, 'get_attr'):
271
+ try:
272
+ envs = self.training_env.get_attr('portfolio_values')
273
+ if envs and len(envs[0]) > 1:
274
+ pv = envs[0][-1]
275
+ ret = (pv - initial_amount) / initial_amount * 100
276
+ print(f" Step {self.n_calls:>8,}: Portfolio ${pv:,.0f} ({ret:+.1f}%)")
277
+ except:
278
+ pass
279
+ return True
280
+
281
+ # SAC model (FinRL-Contest recipe)
282
+ model = SAC(
283
+ policy="MlpPolicy",
284
+ env=train_env,
285
+ learning_rate=learning_rate,
286
+ batch_size=batch_size,
287
+ buffer_size=buffer_size,
288
+ learning_starts=max(1000, batch_size * 4),
289
+ gamma=gamma,
290
+ tau=tau,
291
+ ent_coef="auto", # Auto-tune entropy (key SAC feature)
292
+ target_entropy="auto",
293
+ train_freq=1,
294
+ gradient_steps=1,
295
+ policy_kwargs=dict(net_arch=list(net_arch)),
296
+ verbose=1,
297
+ seed=seed,
298
+ tensorboard_log="./logs/sac_crypto/",
299
+ )
300
+
301
+ print(f"\nModel parameters: {sum(p.numel() for p in model.policy.parameters()):,}")
302
+
303
+ # Eval callback
304
+ os.makedirs(save_dir, exist_ok=True)
305
+ eval_callback = EvalCallback(
306
+ val_env,
307
+ best_model_save_path=save_dir,
308
+ log_path=save_dir,
309
+ eval_freq=max(5000, total_timesteps // 20),
310
+ n_eval_episodes=1,
311
+ deterministic=True,
312
+ verbose=1,
313
+ )
314
+
315
+ trading_callback = TradingCallback()
316
+
317
+ # Train
318
+ print("\nStarting training...")
319
+ model.learn(
320
+ total_timesteps=total_timesteps,
321
+ callback=[eval_callback, trading_callback],
322
+ progress_bar=False,
323
+ )
324
+
325
+ # Save final model
326
+ final_path = os.path.join(save_dir, "sac_crypto_final")
327
+ model.save(final_path)
328
+ train_env.save(os.path.join(save_dir, "vec_normalize.pkl"))
329
+
330
+ print(f"\n✓ Model saved to {final_path}")
331
+
332
+ return model, train_env
333
+
334
+
335
+ # ============================================================
336
+ # 3. EVALUATION & BACKTESTING
337
+ # ============================================================
338
+
339
+ def evaluate_agent(model, df_test, train_env, initial_amount=100_000.0,
340
+ commission=0.001, max_btc=10.0, reward_scaling=1e-4):
341
+ """Backtest trained agent on test data."""
342
+ from stable_baselines3.common.vec_env import DummyVecEnv, VecNormalize
343
+ from crypto_trading_env import SingleAssetTradingEnv
344
+
345
+ print("\n" + "="*60)
346
+ print("BACKTESTING ON TEST DATA")
347
+ print("="*60)
348
+
349
+ # Create test environment
350
+ test_env_raw = SingleAssetTradingEnv(
351
+ df=df_test,
352
+ initial_amount=initial_amount,
353
+ commission_rate=commission,
354
+ reward_scaling=reward_scaling,
355
+ max_btc=max_btc,
356
+ )
357
+
358
+ # Run agent
359
+ obs, _ = test_env_raw.reset()
360
+
361
+ portfolio_values = [initial_amount]
362
+ actions_taken = []
363
+ done = False
364
+
365
+ while not done:
366
+ action, _ = model.predict(obs, deterministic=True)
367
+ obs, reward, terminated, truncated, info = test_env_raw.step(action)
368
+ done = terminated or truncated
369
+ portfolio_values.append(info['portfolio_value'])
370
+ actions_taken.append(float(action[0]))
371
+
372
+ # Calculate metrics
373
+ portfolio_values = np.array(portfolio_values)
374
+
375
+ # Total return
376
+ total_return = (portfolio_values[-1] - initial_amount) / initial_amount * 100
377
+
378
+ # Daily returns
379
+ daily_returns = np.diff(portfolio_values) / portfolio_values[:-1]
380
+
381
+ # Sharpe ratio (annualized, assuming 365 trading days for crypto)
382
+ if len(daily_returns) > 1 and np.std(daily_returns) > 0:
383
+ sharpe = np.sqrt(365) * np.mean(daily_returns) / np.std(daily_returns)
384
+ else:
385
+ sharpe = 0.0
386
+
387
+ # Max drawdown
388
+ peak = np.maximum.accumulate(portfolio_values)
389
+ drawdown = (peak - portfolio_values) / peak
390
+ max_drawdown = np.max(drawdown) * 100
391
+
392
+ # Sortino ratio
393
+ downside = daily_returns[daily_returns < 0]
394
+ if len(downside) > 0:
395
+ sortino = np.sqrt(365) * np.mean(daily_returns) / np.std(downside)
396
+ else:
397
+ sortino = float('inf')
398
+
399
+ # Buy & Hold comparison
400
+ bh_return = (df_test['close'].iloc[-1] - df_test['close'].iloc[0]) / df_test['close'].iloc[0] * 100
401
+ bh_values = initial_amount * df_test['close'].values / df_test['close'].iloc[0]
402
+ bh_daily_returns = np.diff(bh_values) / bh_values[:-1]
403
+ if len(bh_daily_returns) > 1 and np.std(bh_daily_returns) > 0:
404
+ bh_sharpe = np.sqrt(365) * np.mean(bh_daily_returns) / np.std(bh_daily_returns)
405
+ else:
406
+ bh_sharpe = 0.0
407
+ bh_peak = np.maximum.accumulate(bh_values)
408
+ bh_dd = np.max((bh_peak - bh_values) / bh_peak) * 100
409
+
410
+ # Action statistics
411
+ actions_arr = np.array(actions_taken)
412
+ n_buy = np.sum(actions_arr > 0.1)
413
+ n_sell = np.sum(actions_arr < -0.1)
414
+ n_hold = len(actions_arr) - n_buy - n_sell
415
+
416
+ print(f"\n{'Metric':<25} {'SAC Agent':>15} {'Buy & Hold':>15}")
417
+ print("-" * 57)
418
+ print(f"{'Total Return':<25} {total_return:>14.2f}% {bh_return:>14.2f}%")
419
+ print(f"{'Sharpe Ratio':<25} {sharpe:>15.3f} {bh_sharpe:>15.3f}")
420
+ print(f"{'Sortino Ratio':<25} {sortino:>15.3f} {'N/A':>15}")
421
+ print(f"{'Max Drawdown':<25} {max_drawdown:>14.2f}% {bh_dd:>14.2f}%")
422
+ print(f"{'Final Portfolio':<25} ${portfolio_values[-1]:>13,.0f} ${bh_values[-1]:>13,.0f}")
423
+ print(f"\nActions: {n_buy} buys, {n_sell} sells, {n_hold} holds")
424
+ print(f"Mean action: {actions_arr.mean():.4f}, Std: {actions_arr.std():.4f}")
425
+
426
+ results = {
427
+ "total_return_pct": round(total_return, 2),
428
+ "sharpe_ratio": round(sharpe, 3),
429
+ "sortino_ratio": round(sortino, 3),
430
+ "max_drawdown_pct": round(max_drawdown, 2),
431
+ "final_portfolio": round(portfolio_values[-1], 2),
432
+ "buy_hold_return_pct": round(bh_return, 2),
433
+ "buy_hold_sharpe": round(bh_sharpe, 3),
434
+ "n_trades_buy": int(n_buy),
435
+ "n_trades_sell": int(n_sell),
436
+ "test_days": len(df_test),
437
+ }
438
+
439
+ return results, portfolio_values, actions_taken
440
+
441
+
442
+ # ============================================================
443
+ # 4. MAIN
444
+ # ============================================================
445
+
446
+ def main():
447
+ import argparse
448
+
449
+ parser = argparse.ArgumentParser(description="SAC Crypto Trading Agent")
450
+ parser.add_argument("--symbol", default="BTCUSDT", help="Trading pair")
451
+ parser.add_argument("--timeframe", default="1d", help="Candle timeframe")
452
+ parser.add_argument("--timesteps", type=int, default=200_000, help="Total training timesteps")
453
+ parser.add_argument("--lr", type=float, default=3e-4, help="Learning rate")
454
+ parser.add_argument("--batch_size", type=int, default=64, help="Batch size")
455
+ parser.add_argument("--buffer_size", type=int, default=100_000, help="Replay buffer size")
456
+ parser.add_argument("--gamma", type=float, default=0.99, help="Discount factor")
457
+ parser.add_argument("--tau", type=float, default=0.005, help="Target network update rate")
458
+ parser.add_argument("--net_arch", type=int, nargs="+", default=[64, 32], help="Network architecture")
459
+ parser.add_argument("--initial_amount", type=float, default=100_000.0, help="Starting capital")
460
+ parser.add_argument("--commission", type=float, default=0.001, help="Trading commission rate")
461
+ parser.add_argument("--max_btc", type=float, default=10.0, help="Max BTC per trade")
462
+ parser.add_argument("--reward_scaling", type=float, default=1e-4, help="Reward scaling factor")
463
+ parser.add_argument("--seed", type=int, default=42, help="Random seed")
464
+ parser.add_argument("--save_dir", default="./sac_crypto_model", help="Model save directory")
465
+ parser.add_argument("--push_to_hub", action="store_true", help="Push model to HF Hub")
466
+ parser.add_argument("--hub_model_id", default=None, help="HF Hub model ID")
467
+
468
+ args = parser.parse_args()
469
+
470
+ # Load and prepare data
471
+ print("=" * 60)
472
+ print("SAC CRYPTO TRADING AGENT")
473
+ print(f"Symbol: {args.symbol}, Timeframe: {args.timeframe}")
474
+ print(f"Training timesteps: {args.timesteps:,}")
475
+ print("=" * 60)
476
+
477
+ df_train, df_val, df_test = prepare_data(
478
+ symbol=args.symbol,
479
+ timeframe=args.timeframe,
480
+ )
481
+
482
+ # Train
483
+ model, train_env = train_sac_agent(
484
+ df_train=df_train,
485
+ df_val=df_val,
486
+ total_timesteps=args.timesteps,
487
+ learning_rate=args.lr,
488
+ batch_size=args.batch_size,
489
+ buffer_size=args.buffer_size,
490
+ gamma=args.gamma,
491
+ tau=args.tau,
492
+ net_arch=tuple(args.net_arch),
493
+ initial_amount=args.initial_amount,
494
+ commission=args.commission,
495
+ max_btc=args.max_btc,
496
+ reward_scaling=args.reward_scaling,
497
+ seed=args.seed,
498
+ save_dir=args.save_dir,
499
+ )
500
+
501
+ # Evaluate
502
+ results, portfolio_values, actions = evaluate_agent(
503
+ model=model,
504
+ df_test=df_test,
505
+ train_env=train_env,
506
+ initial_amount=args.initial_amount,
507
+ commission=args.commission,
508
+ max_btc=args.max_btc,
509
+ reward_scaling=args.reward_scaling,
510
+ )
511
+
512
+ # Save results
513
+ results_path = os.path.join(args.save_dir, "results.json")
514
+ with open(results_path, 'w') as f:
515
+ json.dump(results, f, indent=2)
516
+ print(f"\n✓ Results saved to {results_path}")
517
+
518
+ # Push to Hub
519
+ if args.push_to_hub and args.hub_model_id:
520
+ try:
521
+ from huggingface_hub import HfApi
522
+ api = HfApi()
523
+ api.create_repo(args.hub_model_id, exist_ok=True)
524
+ api.upload_folder(
525
+ folder_path=args.save_dir,
526
+ repo_id=args.hub_model_id,
527
+ commit_message=f"SAC crypto agent - {args.symbol} - Sharpe {results['sharpe_ratio']}"
528
+ )
529
+ print(f"\n✓ Model pushed to https://huggingface.co/{args.hub_model_id}")
530
+ except Exception as e:
531
+ print(f"⚠ Failed to push to hub: {e}")
532
+
533
+ return results
534
+
535
+
536
+ if __name__ == "__main__":
537
+ main()