File size: 7,980 Bytes
09dd12e | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 | import os, argparse, json, math, sys
sys.path.insert(0, '/app')
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torch.optim import AdamW
from finjepa.model import FinJEPA, FinJEPALoss
from finjepa.data import FinancialTrajectoryDataset, build_dataloaders, generate_synthetic_data, load_hf_stock_data
def get_args(argv=None):
    """Parse the command-line options for a FinJEPA training run.

    Args:
        argv: optional list of argument strings to parse. ``None`` (the
            default) lets argparse read ``sys.argv[1:]``, which is exactly the
            previous zero-argument behaviour.

    Returns:
        argparse.Namespace holding every option declared below.
    """
    parser = argparse.ArgumentParser()
    # Data source / windowing
    parser.add_argument('--data_source', type=str, default='synthetic', choices=['synthetic', 'hf'])
    parser.add_argument('--dataset_name', type=str, default='paperswithbacktest/Stocks-Daily-Price')
    parser.add_argument('--symbols', type=str, default=None)
    parser.add_argument('--n_assets', type=int, default=1)
    parser.add_argument('--context_window', type=int, default=60)
    parser.add_argument('--target_window', type=int, default=5)
    parser.add_argument('--batch_size', type=int, default=128)
    # Model architecture
    parser.add_argument('--embed_dim', type=int, default=128)
    parser.add_argument('--encoder_depth', type=int, default=4)
    parser.add_argument('--encoder_heads', type=int, default=4)
    parser.add_argument('--predictor_depth', type=int, default=6)
    parser.add_argument('--predictor_heads', type=int, default=4)
    parser.add_argument('--patch_size', type=int, default=4)
    parser.add_argument('--dropout', type=float, default=0.0)
    parser.add_argument('--ema_decay', type=float, default=0.996)
    # IDM is enabled by default. The previous `store_true` + `default=True`
    # combination made it impossible to switch off from the command line;
    # `--no_use_idm` is the off switch, and `--use_idm` keeps working for
    # existing launch scripts.
    idm_group = parser.add_mutually_exclusive_group()
    idm_group.add_argument('--use_idm', dest='use_idm', action='store_true',
                           help='enable the IDM component (default)')
    idm_group.add_argument('--no_use_idm', dest='use_idm', action='store_false',
                           help='disable the IDM component')
    parser.set_defaults(use_idm=True)
    # Optimisation
    parser.add_argument('--lr', type=float, default=0.001)
    parser.add_argument('--weight_decay', type=float, default=1e-6)
    parser.add_argument('--epochs', type=int, default=50)
    parser.add_argument('--grad_clip', type=float, default=1.0)
    parser.add_argument('--rollout_steps', type=int, default=2)
    # Loss weights (forwarded verbatim to FinJEPALoss)
    parser.add_argument('--alpha', type=float, default=2.0)
    parser.add_argument('--beta', type=float, default=1.0)
    parser.add_argument('--delta', type=float, default=4.0)
    parser.add_argument('--omega', type=float, default=0.5)
    parser.add_argument('--gamma', type=float, default=0.5)
    # Run bookkeeping
    parser.add_argument('--output_dir', type=str, default='./outputs')
    parser.add_argument('--device', type=str, default='cuda' if torch.cuda.is_available() else 'cpu')
    parser.add_argument('--seed', type=int, default=42)
    parser.add_argument('--push_to_hub', action='store_true')
    parser.add_argument('--hub_model_id', type=str, default='ashesh8500/finjepa')
    return parser.parse_args(argv)
def set_seed(seed):
    """Seed the Python, NumPy and PyTorch random number generators.

    All three generators receive the same ``seed``. When CUDA is available,
    the generators of every visible CUDA device are seeded as well.
    """
    import random
    import numpy as np

    # Same order as before: stdlib, NumPy, then the PyTorch CPU generator.
    for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
        seed_fn(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)
def train_epoch(model, loss_fn, dataloader, optimizer, device, grad_clip, rollout_steps):
    """Run one optimisation pass over ``dataloader`` and return mean losses.

    For each batch the model is called, the loss is back-propagated, gradients
    are optionally norm-clipped, the optimiser steps and
    ``model.update_target()`` is invoked (presumably the EMA target update
    governed by ``--ema_decay`` — confirm in FinJEPA).

    Args:
        model: module called as ``model(ctx, tgt, w, s, h)``; must also expose
            ``update_target()``.
        loss_fn: callable ``loss_fn(outputs, actions_gt, rollout_outputs)``
            returning a dict with ``'loss'`` (a tensor that is back-propagated)
            and the components ``'loss_pred'``, ``'loss_reg'``,
            ``'loss_temporal'``, ``'loss_idm'`` and ``'loss_rollout'``, each
            a Python number or a one-element tensor.
        dataloader: iterable of dicts with keys ``'context'``, ``'target'``,
            ``'weights'``, ``'signals'`` and ``'hedge'``.
        optimizer: torch optimiser over ``model``'s parameters.
        device: device the batch tensors are moved to.
        grad_clip: maximum total gradient norm; ``<= 0`` disables clipping.
        rollout_steps: when greater than 1, ``rollout_steps - 1`` additional
            forward passes are collected and handed to ``loss_fn``.

    Returns:
        dict mapping ``'loss'`` and each component name to its per-batch mean
        as a Python float.

    Raises:
        ValueError: if ``dataloader`` yields no batches.
    """
    component_keys = ('loss_pred', 'loss_reg', 'loss_temporal', 'loss_idm', 'loss_rollout')
    model.train()
    totals = dict.fromkeys(('loss',) + component_keys, 0.0)
    n = 0
    for batch in dataloader:
        ctx = batch['context'].to(device)
        tgt = batch['target'].to(device)
        w = batch['weights'].to(device)
        s = batch['signals'].to(device)
        h = batch['hedge'].to(device)
        outputs = model(ctx, tgt, w, s, h)
        actions_gt = {'weights': w, 'signals': s}
        # NOTE(review): every rollout pass re-runs the model on exactly the same
        # inputs as `outputs`, so the rollout outputs can differ from it only
        # through stochastic layers (e.g. dropout). If a genuine multi-step
        # rollout was intended, the model's own predictions presumably need to
        # be fed back in — confirm against FinJEPALoss.
        rollout_outputs = [model(ctx, tgt, w, s, h) for _ in range(1, rollout_steps)]
        loss_dict = loss_fn(outputs, actions_gt, rollout_outputs)
        loss = loss_dict['loss']
        optimizer.zero_grad()
        loss.backward()
        if grad_clip > 0:
            nn.utils.clip_grad_norm_(model.parameters(), grad_clip)
        optimizer.step()
        model.update_target()
        totals['loss'] += loss.item()
        # float() accepts Python numbers and one-element tensors alike. It also
        # stops a component tensor that still carries autograd history from
        # keeping every batch's graph alive until the end of the epoch.
        for key in component_keys:
            totals[key] += float(loss_dict[key])
        n += 1
    if n == 0:
        raise ValueError('train_epoch: dataloader yielded no batches')
    return {key: total / n for key, total in totals.items()}
@torch.no_grad()
def evaluate(model, loss_fn, dataloader, device, rollout_steps):
    """Compute mean validation losses over ``dataloader`` without gradients.

    The model is switched to eval mode and left there; ``train_epoch`` puts it
    back into train mode.

    Args:
        model: module called as ``model(ctx, tgt, w, s, h)``.
        loss_fn: callable ``loss_fn(outputs, actions_gt, rollout_outputs)``
            returning a dict with a tensor ``'loss'`` and a ``'loss_pred'``
            component (a Python number or a one-element tensor).
        dataloader: iterable of dicts with keys ``'context'``, ``'target'``,
            ``'weights'``, ``'signals'`` and ``'hedge'``.
        device: device the batch tensors are moved to.
        rollout_steps: when greater than 1, ``rollout_steps - 1`` additional
            forward passes are collected and handed to ``loss_fn``.

    Returns:
        ``{'loss': float, 'loss_pred': float}``, each a per-batch mean.

    Raises:
        ValueError: if ``dataloader`` yields no batches.
    """
    model.eval()
    total_loss = 0.0
    total_pred = 0.0
    n = 0
    for batch in dataloader:
        ctx = batch['context'].to(device)
        tgt = batch['target'].to(device)
        w = batch['weights'].to(device)
        s = batch['signals'].to(device)
        h = batch['hedge'].to(device)
        outputs = model(ctx, tgt, w, s, h)
        actions_gt = {'weights': w, 'signals': s}
        # NOTE(review): mirrors train_epoch — the rollout passes reuse the same
        # inputs, so in eval mode they are presumably identical to `outputs`.
        rollout_outputs = [model(ctx, tgt, w, s, h) for _ in range(1, rollout_steps)]
        loss_dict = loss_fn(outputs, actions_gt, rollout_outputs)
        total_loss += loss_dict['loss'].item()
        # float() yields a plain Python float even when the component is a
        # (possibly GPU-resident) one-element tensor.
        total_pred += float(loss_dict['loss_pred'])
        n += 1
    if n == 0:
        raise ValueError('evaluate: dataloader yielded no batches')
    return {'loss': total_loss / n, 'loss_pred': total_pred / n}
def main():
    """Train FinJEPA from the command line.

    Steps:
        1. Parse options and seed the RNGs.
        2. Write ``config.json`` into ``--output_dir``.
        3. Load synthetic or Hugging Face data and build the dataloaders.
        4. Train for ``--epochs`` epochs. After each epoch, save
           ``best_model.pt`` whenever the validation loss improves.
        5. Save the final weights to ``final_model.pt``.
        6. Optionally upload the output directory to the Hugging Face Hub.
    """
    args = get_args()
    set_seed(args.seed)
    os.makedirs(args.output_dir, exist_ok=True)
    # Persist the exact run configuration next to the checkpoints.
    with open(os.path.join(args.output_dir, 'config.json'), 'w') as f:
        json.dump(vars(args), f, indent=2)

    if args.data_source == 'synthetic':
        df = generate_synthetic_data(n_timesteps=10000, n_assets=args.n_assets)
    else:
        # Accept "AAPL, MSFT" or a trailing comma.
        # Whitespace around each symbol is stripped and empty entries are
        # dropped. A list that ends up empty falls back to None, the same value
        # used when --symbols is not given at all.
        symbols = None
        if args.symbols:
            symbols = [sym.strip() for sym in args.symbols.split(',') if sym.strip()] or None
        df = load_hf_stock_data(args.dataset_name, symbols=symbols)
    print(f"Data shape: {df.shape}")
    loaders = build_dataloaders(df, n_assets=args.n_assets, context_window=args.context_window,
                                target_window=args.target_window, batch_size=args.batch_size)

    # NOTE(review): hard-coded input width. It presumably has to match the
    # per-asset feature count produced by the finjepa.data pipeline — confirm.
    n_features = 14
    model = FinJEPA(in_features=n_features, n_assets=args.n_assets, patch_size=args.patch_size,
                    embed_dim=args.embed_dim, encoder_depth=args.encoder_depth, encoder_heads=args.encoder_heads,
                    predictor_depth=args.predictor_depth, predictor_heads=args.predictor_heads,
                    dropout=args.dropout, ema_decay=args.ema_decay, use_idm=args.use_idm).to(args.device)
    loss_fn = FinJEPALoss(pred_loss='l1', alpha=args.alpha, beta=args.beta, delta=args.delta,
                          omega=args.omega, gamma=args.gamma).to(args.device)
    optimizer = AdamW(model.parameters(), lr=args.lr, betas=(0.9, 0.995), weight_decay=args.weight_decay)
    print(f"Trainable params: {sum(p.numel() for p in model.parameters() if p.requires_grad):,}")

    best_val = float('inf')
    for epoch in range(args.epochs):
        train_metrics = train_epoch(model, loss_fn, loaders['train'], optimizer, args.device, args.grad_clip, args.rollout_steps)
        val_metrics = evaluate(model, loss_fn, loaders['val'], args.device, args.rollout_steps)
        print(f"Epoch {epoch+1}/{args.epochs} | train_loss={train_metrics['loss']:.4f} "
              f"(pred={train_metrics['loss_pred']:.4f} reg={train_metrics['loss_reg']:.4f} "
              f"temp={train_metrics['loss_temporal']:.4f} idm={train_metrics['loss_idm']:.4f} "
              f"rollout={train_metrics['loss_rollout']:.4f}) | val_loss={val_metrics['loss']:.4f}")
        # Keep the checkpoint with the lowest validation loss seen so far.
        if val_metrics['loss'] < best_val:
            best_val = val_metrics['loss']
            torch.save(model.state_dict(), os.path.join(args.output_dir, 'best_model.pt'))
    torch.save(model.state_dict(), os.path.join(args.output_dir, 'final_model.pt'))

    if args.push_to_hub:
        # Imported lazily so huggingface_hub is only required when pushing.
        from huggingface_hub import HfApi
        api = HfApi()
        api.create_repo(args.hub_model_id, repo_type="model", exist_ok=True)
        api.upload_folder(folder_path=args.output_dir, repo_id=args.hub_model_id, repo_type="model")
        print(f"Pushed to https://huggingface.co/{args.hub_model_id}")
if __name__ == '__main__':
    # Start training only when this file is executed as a script,
    # never when it is imported as a module.
    main()
|