# sac-crypto-btc-agent — train_sac_crypto.py
# (repository file header; uploaded by delta0790, commit 9a681d9 "Add training script")
"""
SAC Crypto Trading Agent - Training Script
Based on FinRL-Meta (arXiv:2304.13174) recipe:
- Dataset: linxy/CryptoCoin (Binance OHLCV) on HF Hub
- SAC hyperparams: lr=3e-4, batch=64, net_arch=[64,32], ent_coef=auto
- Technical indicators: MACD, RSI(30), CCI(30), DX(30), SMA(30), Bollinger Bands
- Reward: ΔPortfolioValue * scaling
- Commission: 0.1% (Binance spot)
Usage:
pip install stable-baselines3 gymnasium huggingface_hub pandas numpy tensorboard
python train_sac_crypto.py \
--symbol BTCUSDT \
--timeframe 1d \
--timesteps 200000 \
--lr 3e-4 \
--batch_size 64 \
--buffer_size 100000 \
--gamma 0.99 \
--tau 0.005 \
--net_arch 64 32 \
--initial_amount 100000 \
--commission 0.001 \
--max_btc 10.0 \
--reward_scaling 1e-4 \
--seed 42 \
--save_dir ./sac_crypto_model \
--push_to_hub \
--hub_model_id YOUR_USERNAME/sac-crypto-btc-agent
"""
import os
import json
import numpy as np
import pandas as pd
from io import StringIO
from datetime import datetime
# ============================================================
# 1. DATA LOADING & FEATURE ENGINEERING
# ============================================================
def load_crypto_data_from_hf(symbol="BTCUSDT", timeframe="1d"):
    """Load crypto OHLCV candles from the HF Hub dataset ``linxy/CryptoCoin``.

    Parameters
    ----------
    symbol : str
        Trading pair, e.g. "BTCUSDT" (selects the ``<symbol>_<timeframe>.csv`` file).
    timeframe : str
        Candle interval, e.g. "1d".

    Returns
    -------
    pandas.DataFrame
        Columns date/open/high/low/close/volume, sorted ascending by date,
        with NaN rows dropped.

    Raises
    ------
    ValueError
        If the downloaded CSV has no recognizable date column.
    """
    from huggingface_hub import hf_hub_download
    filename = f"{symbol}_{timeframe}.csv"
    # BUG FIX: this previously printed the literal text "(unknown)" instead
    # of interpolating the filename being downloaded.
    print(f"Downloading {filename} from linxy/CryptoCoin...")
    path = hf_hub_download(
        repo_id="linxy/CryptoCoin",
        filename=filename,
        repo_type="dataset",
    )
    df = pd.read_csv(path)
    # Standardize column names: 'Open time' is the Binance export header;
    # the lowercase OHLCV names map to themselves (harmless no-ops).
    col_map = {
        'Open time': 'date',
        'open': 'open',
        'high': 'high',
        'low': 'low',
        'close': 'close',
        'volume': 'volume',
    }
    df = df.rename(columns=col_map)
    # Keep only the columns the pipeline needs; extra columns are ignored.
    keep = ['date', 'open', 'high', 'low', 'close', 'volume']
    df = df[[c for c in keep if c in df.columns]]
    # Fail loudly (instead of a bare KeyError below) if the schema changed.
    if 'date' not in df.columns:
        raise ValueError(
            f"{filename} has no recognizable date column; got {list(df.columns)}"
        )
    df['date'] = pd.to_datetime(df['date'])
    df = df.sort_values('date').reset_index(drop=True)
    # Drop NaN rows so indicators never see gaps.
    df = df.dropna().reset_index(drop=True)
    print(f"Loaded {len(df)} rows for {symbol} ({timeframe})")
    print(f"  Date range: {df['date'].iloc[0]} to {df['date'].iloc[-1]}")
    print(f"  Price range: ${df['close'].min():.2f} - ${df['close'].max():.2f}")
    return df
def add_technical_indicators(df):
    """Append normalized technical-indicator columns to an OHLCV frame.

    Computed directly with pandas/numpy (no stockstats dependency):
    MACD (12/26 EMA spread + 9-EMA signal), RSI (14-period, rescaled to
    [-1, 1]), CCI (20-period, divided by 200), DX (14-period, scaled to
    [0, 1]), 30-day SMA ratio, Bollinger-band distances (20-period, 2 std)
    and a clipped volume-change ratio.  Column names keep the historical
    ``*_30`` suffixes for compatibility with the rest of the pipeline.

    Returns a copy of *df* with the indicator columns appended; NaNs from
    warm-up windows are replaced with 0.
    """
    out = df.copy()
    c, h, l = out['close'], out['high'], out['low']

    # --- MACD: fast/slow EMA spread plus its 9-period EMA signal line ---
    fast = c.ewm(span=12, adjust=False).mean()
    slow = c.ewm(span=26, adjust=False).mean()
    out['macd'] = fast - slow
    out['macd_signal'] = out['macd'].ewm(span=9, adjust=False).mean()
    out['macd_hist'] = out['macd'] - out['macd_signal']

    # --- RSI over 14 periods, recentred from [0, 100] to [-1, 1] ---
    move = c.diff()
    avg_up = move.where(move > 0, 0.0).rolling(window=14, min_periods=1).mean()
    avg_down = (-move.where(move < 0, 0.0)).rolling(window=14, min_periods=1).mean()
    strength = avg_up / (avg_down + 1e-10)
    out['rsi_30'] = ((100 - (100 / (1 + strength))) - 50) / 50

    # --- CCI over 20 periods, then softened by a /200 normalization ---
    tp = (h + l + c) / 3
    tp_mean = tp.rolling(window=20, min_periods=1).mean()
    tp_mad = tp.rolling(window=20, min_periods=1).apply(
        lambda w: np.abs(w - w.mean()).mean(), raw=True
    )
    out['cci_30'] = ((tp - tp_mean) / (0.015 * tp_mad + 1e-10)) / 200

    # --- Directional index (DX) over 14 periods, scaled to [0, 1] ---
    up_move = h.diff()
    down_move = -l.diff()
    pos_dm = up_move.where((up_move > down_move) & (up_move > 0), 0.0)
    # NOTE: the negative DM is filtered against the *already filtered*
    # positive DM, matching the original computation exactly.
    neg_dm = down_move.where((down_move > pos_dm) & (down_move > 0), 0.0)
    prev_close = c.shift(1)
    true_range = pd.concat([
        h - l,
        (h - prev_close).abs(),
        (l - prev_close).abs()
    ], axis=1).max(axis=1)
    smoothed_tr = true_range.rolling(window=14, min_periods=1).mean()
    di_pos = 100 * pos_dm.rolling(14, min_periods=1).mean() / (smoothed_tr + 1e-10)
    di_neg = 100 * neg_dm.rolling(14, min_periods=1).mean() / (smoothed_tr + 1e-10)
    out['dx_30'] = (100 * (di_pos - di_neg).abs() / (di_pos + di_neg + 1e-10)) / 100

    # --- Price distance from its 30-day SMA, as a ratio ---
    sma = c.rolling(window=30, min_periods=1).mean()
    out['close_30_sma'] = (c - sma) / (sma + 1e-10)

    # --- Bollinger bands (20-period mean, 2 std), as distances from close ---
    mid = c.rolling(window=20, min_periods=1).mean()
    band = c.rolling(window=20, min_periods=1).std()
    out['boll_ub'] = (c - (mid + 2 * band)) / (c + 1e-10)
    out['boll_lb'] = (c - (mid - 2 * band)) / (c + 1e-10)

    # --- Volume change ratio, clipped to tame spikes ---
    out['volume_change'] = out['volume'].pct_change().fillna(0).clip(-5, 5)

    # Warm-up windows leave NaNs in the first rows; zero them out.
    out = out.fillna(0)
    base_cols = ['date', 'open', 'high', 'low', 'close', 'volume']
    n_added = len([col for col in out.columns if col not in base_cols])
    print(f"Added {n_added} technical indicators")
    return out
def prepare_data(symbol="BTCUSDT", timeframe="1d", train_ratio=0.7, val_ratio=0.15):
    """Load data, add indicators, and split chronologically into train/val/test.

    Parameters
    ----------
    symbol, timeframe : forwarded to ``load_crypto_data_from_hf``.
    train_ratio, val_ratio : float
        Fractions of rows for the train and validation splits; the
        remainder becomes the test split.

    Returns
    -------
    (df_train, df_val, df_test) : chronologically ordered DataFrames.

    Raises
    ------
    ValueError
        If the ratios are out of range or leave no room for a test split.
    """
    # Validate up front — a bad ratio used to surface later as a confusing
    # IndexError in the summary prints below.
    if not (0 < train_ratio < 1) or val_ratio < 0 or train_ratio + val_ratio >= 1:
        raise ValueError(
            f"Invalid split ratios: train={train_ratio}, val={val_ratio} "
            "(need 0 < train_ratio < 1 and train_ratio + val_ratio < 1)"
        )
    df = load_crypto_data_from_hf(symbol, timeframe)
    df = add_technical_indicators(df)
    n = len(df)
    train_end = int(n * train_ratio)
    val_end = int(n * (train_ratio + val_ratio))
    df_train = df.iloc[:train_end].reset_index(drop=True)
    df_val = df.iloc[train_end:val_end].reset_index(drop=True)
    df_test = df.iloc[val_end:].reset_index(drop=True)

    def _span(part):
        # Human-readable date range; safe for an empty split (tiny datasets).
        if len(part) == 0:
            return "empty"
        return f"{part['date'].iloc[0].date()} to {part['date'].iloc[-1].date()}"

    print(f"\nData splits:")
    print(f"  Train: {len(df_train)} days ({_span(df_train)})")
    print(f"  Val:   {len(df_val)} days ({_span(df_val)})")
    print(f"  Test:  {len(df_test)} days ({_span(df_test)})")
    return df_train, df_val, df_test
# ============================================================
# 2. TRAINING
# ============================================================
def train_sac_agent(
    df_train,
    df_val,
    total_timesteps=200_000,
    learning_rate=3e-4,
    batch_size=64,
    buffer_size=100_000,
    gamma=0.99,
    tau=0.005,
    net_arch=(64, 32),
    initial_amount=100_000.0,
    commission=0.001,
    max_btc=10.0,
    reward_scaling=1e-4,
    seed=42,
    save_dir="./sac_crypto_model",
):
    """Train a SAC agent on the single-asset crypto trading environment.

    Parameters
    ----------
    df_train, df_val : DataFrames with OHLCV + indicator columns (see
        ``add_technical_indicators``), used for training and periodic eval.
    total_timesteps : int
        Number of environment steps to train for.
    learning_rate, batch_size, buffer_size, gamma, tau :
        Standard SAC hyperparameters, passed straight to stable-baselines3.
    net_arch : tuple[int, ...]
        Hidden-layer sizes for the policy/value MLPs.
    initial_amount, commission, max_btc, reward_scaling :
        Environment parameters forwarded to ``SingleAssetTradingEnv``.
    seed : int
        RNG seed for reproducibility.
    save_dir : str
        Directory for the best/final model and the VecNormalize statistics.

    Returns
    -------
    (model, train_env)
        The trained SAC model and the *normalized* training env — its
        observation statistics are needed again at evaluation time.
    """
    from stable_baselines3 import SAC
    from stable_baselines3.common.env_checker import check_env
    from stable_baselines3.common.vec_env import DummyVecEnv, VecNormalize
    from stable_baselines3.common.callbacks import EvalCallback, BaseCallback
    from crypto_trading_env import SingleAssetTradingEnv

    print("\n" + "="*60)
    print("TRAINING SAC CRYPTO AGENT")
    print("="*60)
    print(f"  Timesteps: {total_timesteps:,}")
    print(f"  LR: {learning_rate}, Batch: {batch_size}")
    print(f"  Net arch: {list(net_arch)}")
    print(f"  Buffer: {buffer_size:,}, Gamma: {gamma}, Tau: {tau}")
    print(f"  Initial amount: ${initial_amount:,.0f}")
    print(f"  Commission: {commission*100:.1f}%")
    print("="*60)

    def make_train_env():
        # Factory for DummyVecEnv: fresh env bound to the training split.
        return SingleAssetTradingEnv(
            df=df_train,
            initial_amount=initial_amount,
            commission_rate=commission,
            reward_scaling=reward_scaling,
            max_btc=max_btc,
        )

    def make_val_env():
        # Factory for DummyVecEnv: fresh env bound to the validation split.
        return SingleAssetTradingEnv(
            df=df_val,
            initial_amount=initial_amount,
            commission_rate=commission,
            reward_scaling=reward_scaling,
            max_btc=max_btc,
        )

    # Sanity-check the custom env against the Gymnasium API before training.
    test_env = make_train_env()
    check_env(test_env, warn=True)
    print("✓ Environment passed check_env validation")
    del test_env

    # Single-worker vectorized environments.
    train_env = DummyVecEnv([make_train_env])
    val_env = DummyVecEnv([make_val_env])
    # Normalize observations only — reward scaling is handled inside the env.
    # The val wrapper never updates its statistics (training=False).
    train_env = VecNormalize(train_env, norm_obs=True, norm_reward=False,
                             clip_obs=10.0, gamma=gamma)
    val_env = VecNormalize(val_env, norm_obs=True, norm_reward=False,
                           clip_obs=10.0, training=False, gamma=gamma)

    class TradingCallback(BaseCallback):
        """Print the current training portfolio value every 10k steps."""

        def _on_step(self) -> bool:
            if self.n_calls % 10000 == 0:
                if hasattr(self.training_env, 'get_attr'):
                    try:
                        envs = self.training_env.get_attr('portfolio_values')
                        if envs and len(envs[0]) > 1:
                            pv = envs[0][-1]
                            ret = (pv - initial_amount) / initial_amount * 100
                            print(f"  Step {self.n_calls:>8,}: Portfolio ${pv:,.0f} ({ret:+.1f}%)")
                    # Narrowed from a bare `except:` — still best-effort
                    # logging that must never interrupt training, but no
                    # longer swallows KeyboardInterrupt/SystemExit.
                    except Exception:
                        pass
            return True

    # SAC model (FinRL-Contest recipe).
    model = SAC(
        policy="MlpPolicy",
        env=train_env,
        learning_rate=learning_rate,
        batch_size=batch_size,
        buffer_size=buffer_size,
        learning_starts=max(1000, batch_size * 4),
        gamma=gamma,
        tau=tau,
        ent_coef="auto",  # Auto-tune entropy (key SAC feature)
        target_entropy="auto",
        train_freq=1,
        gradient_steps=1,
        policy_kwargs=dict(net_arch=list(net_arch)),
        verbose=1,
        seed=seed,
        tensorboard_log="./logs/sac_crypto/",
    )
    print(f"\nModel parameters: {sum(p.numel() for p in model.policy.parameters()):,}")

    # Periodic deterministic evaluation on the validation env; keeps the
    # best checkpoint inside save_dir.
    os.makedirs(save_dir, exist_ok=True)
    eval_callback = EvalCallback(
        val_env,
        best_model_save_path=save_dir,
        log_path=save_dir,
        eval_freq=max(5000, total_timesteps // 20),
        n_eval_episodes=1,
        deterministic=True,
        verbose=1,
    )
    trading_callback = TradingCallback()

    print("\nStarting training...")
    model.learn(
        total_timesteps=total_timesteps,
        callback=[eval_callback, trading_callback],
        progress_bar=False,
    )

    # Persist both the final model and the obs-normalization statistics;
    # both are required to reproduce the agent's behavior at inference time.
    final_path = os.path.join(save_dir, "sac_crypto_final")
    model.save(final_path)
    train_env.save(os.path.join(save_dir, "vec_normalize.pkl"))
    print(f"\n✓ Model saved to {final_path}")
    return model, train_env
# ============================================================
# 3. EVALUATION & BACKTESTING
# ============================================================
def evaluate_agent(model, df_test, train_env, initial_amount=100_000.0,
                   commission=0.001, max_btc=10.0, reward_scaling=1e-4):
    """Backtest a trained agent on held-out test data.

    Parameters
    ----------
    model : trained SB3 model exposing ``predict``.
    df_test : test split with OHLCV + indicator columns.
    train_env : VecNormalize wrapper returned by ``train_sac_agent``; its
        running observation statistics are applied to the test observations
        so the model sees inputs on the same scale it was trained on.
        May be None to evaluate on raw observations.
    initial_amount, commission, max_btc, reward_scaling :
        Environment parameters (should match training for a fair comparison).

    Returns
    -------
    (results, portfolio_values, actions_taken)
        ``results`` is a JSON-serializable metrics dict; the other two are
        the per-step portfolio values (np.ndarray) and actions (list[float]).
    """
    from crypto_trading_env import SingleAssetTradingEnv
    print("\n" + "="*60)
    print("BACKTESTING ON TEST DATA")
    print("="*60)
    # Create test environment
    test_env_raw = SingleAssetTradingEnv(
        df=df_test,
        initial_amount=initial_amount,
        commission_rate=commission,
        reward_scaling=reward_scaling,
        max_btc=max_btc,
    )
    # Run the agent deterministically over the full test episode.
    obs, _ = test_env_raw.reset()
    portfolio_values = [initial_amount]
    actions_taken = []
    done = False
    while not done:
        # BUG FIX: train_env was previously accepted but never used, so the
        # model (trained on VecNormalize'd observations) received raw
        # observations here.  Apply the training statistics before predict.
        norm_obs = train_env.normalize_obs(obs) if train_env is not None else obs
        action, _ = model.predict(norm_obs, deterministic=True)
        obs, reward, terminated, truncated, info = test_env_raw.step(action)
        done = terminated or truncated
        portfolio_values.append(info['portfolio_value'])
        actions_taken.append(float(action[0]))

    # --- Metrics ---
    portfolio_values = np.array(portfolio_values)
    total_return = (portfolio_values[-1] - initial_amount) / initial_amount * 100
    daily_returns = np.diff(portfolio_values) / portfolio_values[:-1]
    # Sharpe ratio, annualized with 365 periods (crypto trades every day).
    if len(daily_returns) > 1 and np.std(daily_returns) > 0:
        sharpe = np.sqrt(365) * np.mean(daily_returns) / np.std(daily_returns)
    else:
        sharpe = 0.0
    # Max drawdown as a percentage of the running peak.
    peak = np.maximum.accumulate(portfolio_values)
    drawdown = (peak - portfolio_values) / peak
    max_drawdown = np.max(drawdown) * 100
    # Sortino: penalize downside volatility only; inf when no losing days.
    downside = daily_returns[daily_returns < 0]
    if len(downside) > 0:
        sortino = np.sqrt(365) * np.mean(daily_returns) / np.std(downside)
    else:
        sortino = float('inf')
    # Buy & Hold baseline over the same period.
    bh_return = (df_test['close'].iloc[-1] - df_test['close'].iloc[0]) / df_test['close'].iloc[0] * 100
    bh_values = initial_amount * df_test['close'].values / df_test['close'].iloc[0]
    bh_daily_returns = np.diff(bh_values) / bh_values[:-1]
    if len(bh_daily_returns) > 1 and np.std(bh_daily_returns) > 0:
        bh_sharpe = np.sqrt(365) * np.mean(bh_daily_returns) / np.std(bh_daily_returns)
    else:
        bh_sharpe = 0.0
    bh_peak = np.maximum.accumulate(bh_values)
    bh_dd = np.max((bh_peak - bh_values) / bh_peak) * 100
    # Action statistics (a dead zone of ±0.1 counts as "hold").
    actions_arr = np.array(actions_taken)
    n_buy = np.sum(actions_arr > 0.1)
    n_sell = np.sum(actions_arr < -0.1)
    n_hold = len(actions_arr) - n_buy - n_sell

    print(f"\n{'Metric':<25} {'SAC Agent':>15} {'Buy & Hold':>15}")
    print("-" * 57)
    print(f"{'Total Return':<25} {total_return:>14.2f}% {bh_return:>14.2f}%")
    print(f"{'Sharpe Ratio':<25} {sharpe:>15.3f} {bh_sharpe:>15.3f}")
    print(f"{'Sortino Ratio':<25} {sortino:>15.3f} {'N/A':>15}")
    print(f"{'Max Drawdown':<25} {max_drawdown:>14.2f}% {bh_dd:>14.2f}%")
    print(f"{'Final Portfolio':<25} ${portfolio_values[-1]:>13,.0f} ${bh_values[-1]:>13,.0f}")
    print(f"\nActions: {n_buy} buys, {n_sell} sells, {n_hold} holds")
    print(f"Mean action: {actions_arr.mean():.4f}, Std: {actions_arr.std():.4f}")

    results = {
        "total_return_pct": round(total_return, 2),
        "sharpe_ratio": round(sharpe, 3),
        # Infinity is not valid JSON; report None when there were no losing days.
        "sortino_ratio": round(sortino, 3) if np.isfinite(sortino) else None,
        "max_drawdown_pct": round(max_drawdown, 2),
        "final_portfolio": round(portfolio_values[-1], 2),
        "buy_hold_return_pct": round(bh_return, 2),
        "buy_hold_sharpe": round(bh_sharpe, 3),
        "n_trades_buy": int(n_buy),
        "n_trades_sell": int(n_sell),
        "test_days": len(df_test),
    }
    return results, portfolio_values, actions_taken
# ============================================================
# 4. MAIN
# ============================================================
def main():
    """CLI entry point: parse arguments, prepare data, train the SAC agent,
    backtest it, persist results, and optionally push artifacts to the Hub.

    Returns the backtest metrics dict produced by ``evaluate_agent``.
    """
    import argparse
    parser = argparse.ArgumentParser(description="SAC Crypto Trading Agent")
    parser.add_argument("--symbol", default="BTCUSDT", help="Trading pair")
    parser.add_argument("--timeframe", default="1d", help="Candle timeframe")
    parser.add_argument("--timesteps", type=int, default=200_000, help="Total training timesteps")
    parser.add_argument("--lr", type=float, default=3e-4, help="Learning rate")
    parser.add_argument("--batch_size", type=int, default=64, help="Batch size")
    parser.add_argument("--buffer_size", type=int, default=100_000, help="Replay buffer size")
    parser.add_argument("--gamma", type=float, default=0.99, help="Discount factor")
    parser.add_argument("--tau", type=float, default=0.005, help="Target network update rate")
    parser.add_argument("--net_arch", type=int, nargs="+", default=[64, 32], help="Network architecture")
    parser.add_argument("--initial_amount", type=float, default=100_000.0, help="Starting capital")
    parser.add_argument("--commission", type=float, default=0.001, help="Trading commission rate")
    parser.add_argument("--max_btc", type=float, default=10.0, help="Max BTC per trade")
    parser.add_argument("--reward_scaling", type=float, default=1e-4, help="Reward scaling factor")
    parser.add_argument("--seed", type=int, default=42, help="Random seed")
    parser.add_argument("--save_dir", default="./sac_crypto_model", help="Model save directory")
    parser.add_argument("--push_to_hub", action="store_true", help="Push model to HF Hub")
    parser.add_argument("--hub_model_id", default=None, help="HF Hub model ID")
    args = parser.parse_args()

    # Load and prepare data
    print("=" * 60)
    print("SAC CRYPTO TRADING AGENT")
    print(f"Symbol: {args.symbol}, Timeframe: {args.timeframe}")
    print(f"Training timesteps: {args.timesteps:,}")
    print("=" * 60)
    df_train, df_val, df_test = prepare_data(
        symbol=args.symbol,
        timeframe=args.timeframe,
    )

    # Train
    model, train_env = train_sac_agent(
        df_train=df_train,
        df_val=df_val,
        total_timesteps=args.timesteps,
        learning_rate=args.lr,
        batch_size=args.batch_size,
        buffer_size=args.buffer_size,
        gamma=args.gamma,
        tau=args.tau,
        net_arch=tuple(args.net_arch),
        initial_amount=args.initial_amount,
        commission=args.commission,
        max_btc=args.max_btc,
        reward_scaling=args.reward_scaling,
        seed=args.seed,
        save_dir=args.save_dir,
    )

    # Evaluate on the held-out test split.
    results, portfolio_values, actions = evaluate_agent(
        model=model,
        df_test=df_test,
        train_env=train_env,
        initial_amount=args.initial_amount,
        commission=args.commission,
        max_btc=args.max_btc,
        reward_scaling=args.reward_scaling,
    )

    # Persist backtest metrics next to the model.
    results_path = os.path.join(args.save_dir, "results.json")
    with open(results_path, 'w') as f:
        json.dump(results, f, indent=2)
    print(f"\n✓ Results saved to {results_path}")

    # Push to Hub (best-effort: everything is already saved locally).
    if args.push_to_hub:
        if not args.hub_model_id:
            # Previously this combination silently skipped the upload.
            print("⚠ --push_to_hub was set but --hub_model_id is missing; skipping upload.")
        else:
            try:
                from huggingface_hub import HfApi
                api = HfApi()
                api.create_repo(args.hub_model_id, exist_ok=True)
                api.upload_folder(
                    folder_path=args.save_dir,
                    repo_id=args.hub_model_id,
                    commit_message=f"SAC crypto agent - {args.symbol} - Sharpe {results['sharpe_ratio']}"
                )
                print(f"\n✓ Model pushed to https://huggingface.co/{args.hub_model_id}")
            except Exception as e:
                print(f"⚠ Failed to push to hub: {e}")
    return results
# Script entry point: run the full pipeline (download → train → backtest).
if __name__ == "__main__":
    main()