| import os, argparse, json, math, sys |
| sys.path.insert(0, '/app') |
| import torch |
| import torch.nn as nn |
| from torch.utils.data import DataLoader |
| from torch.optim import AdamW |
| from finjepa.model import FinJEPA, FinJEPALoss |
| from finjepa.data import FinancialTrajectoryDataset, build_dataloaders, generate_synthetic_data, load_hf_stock_data |
|
|
def get_args(argv=None):
    """Build the FinJEPA training CLI and parse it.

    Args:
        argv: Optional list of argument strings. ``None`` (the default) makes
            argparse read ``sys.argv[1:]``, exactly as before; passing a list
            lets tests and notebooks drive the parser directly.

    Returns:
        argparse.Namespace with one attribute per option defined below.
    """
    parser = argparse.ArgumentParser()
    # --- data ---
    parser.add_argument('--data_source', type=str, default='synthetic', choices=['synthetic', 'hf'])
    parser.add_argument('--dataset_name', type=str, default='paperswithbacktest/Stocks-Daily-Price')
    parser.add_argument('--symbols', type=str, default=None)  # comma-separated list, split in main()
    parser.add_argument('--n_assets', type=int, default=1)
    parser.add_argument('--context_window', type=int, default=60)
    parser.add_argument('--target_window', type=int, default=5)
    parser.add_argument('--batch_size', type=int, default=128)
    # --- model ---
    parser.add_argument('--embed_dim', type=int, default=128)
    parser.add_argument('--encoder_depth', type=int, default=4)
    parser.add_argument('--encoder_heads', type=int, default=4)
    parser.add_argument('--predictor_depth', type=int, default=6)
    parser.add_argument('--predictor_heads', type=int, default=4)
    parser.add_argument('--patch_size', type=int, default=4)
    parser.add_argument('--dropout', type=float, default=0.0)
    parser.add_argument('--ema_decay', type=float, default=0.996)
    # `--use_idm` alone is a no-op because its default is already True; the
    # companion `--no_use_idm` writes False into the same destination so the
    # IDM can actually be disabled from the command line.
    parser.add_argument('--use_idm', dest='use_idm', action='store_true', default=True,
                        help='enable the IDM (default: enabled)')
    parser.add_argument('--no_use_idm', dest='use_idm', action='store_false',
                        help='disable the IDM')
    # --- optimisation ---
    parser.add_argument('--lr', type=float, default=0.001)
    parser.add_argument('--weight_decay', type=float, default=1e-6)
    parser.add_argument('--epochs', type=int, default=50)
    parser.add_argument('--grad_clip', type=float, default=1.0)
    parser.add_argument('--rollout_steps', type=int, default=2)
    # --- loss weights (forwarded to FinJEPALoss) ---
    parser.add_argument('--alpha', type=float, default=2.0)
    parser.add_argument('--beta', type=float, default=1.0)
    parser.add_argument('--delta', type=float, default=4.0)
    parser.add_argument('--omega', type=float, default=0.5)
    parser.add_argument('--gamma', type=float, default=0.5)
    # --- run / output ---
    parser.add_argument('--output_dir', type=str, default='./outputs')
    parser.add_argument('--device', type=str, default='cuda' if torch.cuda.is_available() else 'cpu')
    parser.add_argument('--seed', type=int, default=42)
    parser.add_argument('--push_to_hub', action='store_true')
    parser.add_argument('--hub_model_id', type=str, default='ashesh8500/finjepa')
    return parser.parse_args(argv)
|
|
def set_seed(seed):
    """Make a run reproducible by seeding every RNG the script touches.

    Seeds Python's ``random``, NumPy's global RNG and PyTorch's CPU RNG, in
    that order. When CUDA is available, it also seeds all visible GPUs.
    """
    import random
    import numpy as np

    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)
|
|
def train_epoch(model, loss_fn, dataloader, optimizer, device, grad_clip, rollout_steps):
    """Run one optimisation pass over ``dataloader`` and return mean losses.

    For each batch, the function:

    - moves the five batch tensors to ``device``;
    - runs ``model`` once, plus ``rollout_steps - 1`` extra passes collected
      as rollout outputs;
    - scores the result with ``loss_fn``;
    - takes one optimiser step (with optional grad-norm clipping);
    - calls ``model.update_target()``.

    Args:
        model: module called as ``model(context, target, weights, signals,
            hedge)`` that also exposes ``update_target()``.
        loss_fn: callable ``(outputs, actions_gt, rollout_outputs) -> dict``.
            The dict must hold ``'loss'`` (back-propagated) plus the
            single-value components ``loss_pred``, ``loss_reg``,
            ``loss_temporal``, ``loss_idm`` and ``loss_rollout``.
        dataloader: iterable of dicts with keys ``context``, ``target``,
            ``weights``, ``signals`` and ``hedge``.
        optimizer: optimiser over ``model``'s parameters.
        device: device the batch tensors are moved to.
        grad_clip: maximum global gradient norm; ``<= 0`` disables clipping.
        rollout_steps: values ``<= 1`` disable the extra rollout passes.

    Returns:
        dict mapping ``'loss'`` and each component name to its mean over
        batches, as plain Python floats.

    Raises:
        ValueError: if ``dataloader`` yields no batches.
    """
    component_keys = ('loss_pred', 'loss_reg', 'loss_temporal', 'loss_idm', 'loss_rollout')
    model.train()
    totals = dict.fromkeys(('loss',) + component_keys, 0.0)
    n_batches = 0
    for batch in dataloader:
        ctx = batch['context'].to(device)
        tgt = batch['target'].to(device)
        w = batch['weights'].to(device)
        s = batch['signals'].to(device)
        h = batch['hedge'].to(device)
        outputs = model(ctx, tgt, w, s, h)
        actions_gt = {'weights': w, 'signals': s}
        # NOTE(review): every rollout pass repeats the exact call that produced
        # `outputs`. Unless the model is stochastic, each rollout output equals
        # `outputs`. TODO confirm this is intended rather than a multi-step
        # rollout from predicted states.
        rollout_outputs = [model(ctx, tgt, w, s, h) for _ in range(1, rollout_steps)]
        loss_dict = loss_fn(outputs, actions_gt, rollout_outputs)
        loss = loss_dict['loss']
        optimizer.zero_grad()
        loss.backward()
        if grad_clip > 0:
            nn.utils.clip_grad_norm_(model.parameters(), grad_clip)
        optimizer.step()
        model.update_target()
        totals['loss'] += loss.item()
        # float() turns one-element tensors into plain numbers. Summing raw
        # tensors would return tensors, and if loss_fn returns them undetached
        # it would keep every batch's autograd graph alive for the whole epoch.
        for key in component_keys:
            totals[key] += float(loss_dict[key])
        n_batches += 1
    if n_batches == 0:
        raise ValueError('train_epoch: dataloader yielded no batches')
    return {key: total / n_batches for key, total in totals.items()}
|
|
@torch.no_grad()
def evaluate(model, loss_fn, dataloader, device, rollout_steps):
    """Compute mean total and prediction loss over ``dataloader`` without gradients.

    Puts ``model`` in eval mode and leaves it there; ``train_epoch`` switches
    it back with ``model.train()``. Batches are processed exactly as in
    training: one forward pass plus ``rollout_steps - 1`` rollout passes,
    scored by ``loss_fn``.

    Args:
        model: module called as ``model(context, target, weights, signals, hedge)``.
        loss_fn: callable ``(outputs, actions_gt, rollout_outputs) -> dict``
            containing at least ``'loss'`` and ``'loss_pred'``.
        dataloader: iterable of dicts with keys ``context``, ``target``,
            ``weights``, ``signals`` and ``hedge``.
        device: device the batch tensors are moved to.
        rollout_steps: values ``<= 1`` disable the extra rollout passes.

    Returns:
        ``{'loss': float, 'loss_pred': float}``, each averaged over batches.

    Raises:
        ValueError: if ``dataloader`` yields no batches.
    """
    model.eval()
    total_loss = 0.0
    total_pred = 0.0
    n_batches = 0
    for batch in dataloader:
        ctx = batch['context'].to(device)
        tgt = batch['target'].to(device)
        w = batch['weights'].to(device)
        s = batch['signals'].to(device)
        h = batch['hedge'].to(device)
        outputs = model(ctx, tgt, w, s, h)
        actions_gt = {'weights': w, 'signals': s}
        # NOTE(review): the rollout passes repeat the identical call. In eval
        # mode (dropout off) they presumably reproduce `outputs` exactly.
        # TODO confirm the intended rollout semantics.
        rollout_outputs = [model(ctx, tgt, w, s, h) for _ in range(1, rollout_steps)]
        loss_dict = loss_fn(outputs, actions_gt, rollout_outputs)
        # float() yields plain numbers, so the returned metrics are never
        # device tensors.
        total_loss += float(loss_dict['loss'])
        total_pred += float(loss_dict['loss_pred'])
        n_batches += 1
    if n_batches == 0:
        raise ValueError('evaluate: dataloader yielded no batches')
    return {'loss': total_loss / n_batches, 'loss_pred': total_pred / n_batches}
|
|
def main():
    """Script entry point: parse options, train FinJEPA, checkpoint, optionally publish.

    Side effects, in order:

    1. Seed the RNGs.
    2. Create ``args.output_dir`` and write ``config.json`` there.
    3. Generate synthetic data or load it via ``load_hf_stock_data``.
    4. Train for ``args.epochs`` epochs, saving ``best_model.pt`` whenever
       validation loss improves and ``final_model.pt`` at the end.
    5. With ``--push_to_hub``, upload the output directory to the Hugging
       Face Hub.
    """
    args = get_args()
    set_seed(args.seed)
    out_dir = args.output_dir
    os.makedirs(out_dir, exist_ok=True)

    def out_path(filename):
        # Every artefact of the run lives under the output directory.
        return os.path.join(out_dir, filename)

    with open(out_path('config.json'), 'w') as cfg_file:
        json.dump(vars(args), cfg_file, indent=2)

    if args.data_source != 'synthetic':
        symbol_list = args.symbols.split(',') if args.symbols else None
        df = load_hf_stock_data(args.dataset_name, symbols=symbol_list)
    else:
        df = generate_synthetic_data(n_timesteps=10000, n_assets=args.n_assets)
    print(f"Data shape: {df.shape}")

    loaders = build_dataloaders(
        df,
        n_assets=args.n_assets,
        context_window=args.context_window,
        target_window=args.target_window,
        batch_size=args.batch_size,
    )
    # NOTE(review): hard-coded feature count; presumably it must match the
    # per-step feature width produced by build_dataloaders. TODO confirm.
    n_features = 14
    model = FinJEPA(
        in_features=n_features,
        n_assets=args.n_assets,
        patch_size=args.patch_size,
        embed_dim=args.embed_dim,
        encoder_depth=args.encoder_depth,
        encoder_heads=args.encoder_heads,
        predictor_depth=args.predictor_depth,
        predictor_heads=args.predictor_heads,
        dropout=args.dropout,
        ema_decay=args.ema_decay,
        use_idm=args.use_idm,
    ).to(args.device)
    loss_fn = FinJEPALoss(
        pred_loss='l1',
        alpha=args.alpha,
        beta=args.beta,
        delta=args.delta,
        omega=args.omega,
        gamma=args.gamma,
    ).to(args.device)
    optimizer = AdamW(model.parameters(), lr=args.lr, betas=(0.9, 0.995), weight_decay=args.weight_decay)
    n_trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print(f"Trainable params: {n_trainable:,}")

    best_val = float('inf')
    for epoch_no in range(1, args.epochs + 1):
        tm = train_epoch(model, loss_fn, loaders['train'], optimizer, args.device,
                         args.grad_clip, args.rollout_steps)
        vm = evaluate(model, loss_fn, loaders['val'], args.device, args.rollout_steps)
        print(
            f"Epoch {epoch_no}/{args.epochs} | train_loss={tm['loss']:.4f} "
            f"(pred={tm['loss_pred']:.4f} reg={tm['loss_reg']:.4f} "
            f"temp={tm['loss_temporal']:.4f} idm={tm['loss_idm']:.4f} "
            f"rollout={tm['loss_rollout']:.4f}) | val_loss={vm['loss']:.4f}"
        )
        if vm['loss'] < best_val:
            best_val = vm['loss']
            torch.save(model.state_dict(), out_path('best_model.pt'))
    torch.save(model.state_dict(), out_path('final_model.pt'))

    if not args.push_to_hub:
        return
    from huggingface_hub import HfApi
    hub = HfApi()
    hub.create_repo(args.hub_model_id, repo_type="model", exist_ok=True)
    hub.upload_folder(folder_path=out_dir, repo_id=args.hub_model_id, repo_type="model")
    print(f"Pushed to https://huggingface.co/{args.hub_model_id}")
|
|
if __name__ == "__main__":
    # Train only when executed as a script; importing this module has no side effects.
    main()
|
|