File size: 7,980 Bytes
09dd12e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
import os, argparse, json, math, sys
sys.path.insert(0, '/app')
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torch.optim import AdamW
from finjepa.model import FinJEPA, FinJEPALoss
from finjepa.data import FinancialTrajectoryDataset, build_dataloaders, generate_synthetic_data, load_hf_stock_data

def get_args(argv=None):
    """Parse the FinJEPA training command line.

    Args:
        argv: Argument strings to parse. ``None`` (the default) makes argparse
            read ``sys.argv[1:]`` as usual; passing a list allows programmatic
            use and testing.

    Returns:
        ``argparse.Namespace`` holding every option defined below.
    """
    parser = argparse.ArgumentParser(description='Train a FinJEPA model.')

    # --- data ---
    parser.add_argument('--data_source', type=str, default='synthetic', choices=['synthetic', 'hf'])
    parser.add_argument('--dataset_name', type=str, default='paperswithbacktest/Stocks-Daily-Price')
    parser.add_argument('--symbols', type=str, default=None,
                        help='comma-separated symbols to load (only used with --data_source hf)')
    parser.add_argument('--n_assets', type=int, default=1)
    parser.add_argument('--context_window', type=int, default=60)
    parser.add_argument('--target_window', type=int, default=5)
    parser.add_argument('--batch_size', type=int, default=128)

    # --- model ---
    parser.add_argument('--embed_dim', type=int, default=128)
    parser.add_argument('--encoder_depth', type=int, default=4)
    parser.add_argument('--encoder_heads', type=int, default=4)
    parser.add_argument('--predictor_depth', type=int, default=6)
    parser.add_argument('--predictor_heads', type=int, default=4)
    parser.add_argument('--patch_size', type=int, default=4)
    parser.add_argument('--dropout', type=float, default=0.0)
    parser.add_argument('--ema_decay', type=float, default=0.996)
    # use_idm defaults to True. A `store_true` flag whose default is already
    # True can never turn the option off, so `--no_use_idm` stores False into
    # the same destination. `--use_idm` is kept so existing command lines parse.
    parser.add_argument('--use_idm', dest='use_idm', action='store_true', default=True,
                        help='enable use_idm (default: enabled)')
    parser.add_argument('--no_use_idm', dest='use_idm', action='store_false',
                        help='disable use_idm')

    # --- optimisation ---
    parser.add_argument('--lr', type=float, default=0.001)
    parser.add_argument('--weight_decay', type=float, default=1e-6)
    parser.add_argument('--epochs', type=int, default=50)
    parser.add_argument('--grad_clip', type=float, default=1.0)
    parser.add_argument('--rollout_steps', type=int, default=2)

    # --- loss weights ---
    parser.add_argument('--alpha', type=float, default=2.0)
    parser.add_argument('--beta', type=float, default=1.0)
    parser.add_argument('--delta', type=float, default=4.0)
    parser.add_argument('--omega', type=float, default=0.5)
    parser.add_argument('--gamma', type=float, default=0.5)

    # --- run / output ---
    parser.add_argument('--output_dir', type=str, default='./outputs')
    parser.add_argument('--device', type=str, default='cuda' if torch.cuda.is_available() else 'cpu')
    parser.add_argument('--seed', type=int, default=42)
    parser.add_argument('--push_to_hub', action='store_true')
    parser.add_argument('--hub_model_id', type=str, default='ashesh8500/finjepa')
    return parser.parse_args(argv)

def set_seed(seed):
    """Seed every random number generator the run relies on, for repeatability.

    Covers Python's ``random`` module, NumPy's global generator and PyTorch's
    default generator. When CUDA is available, all CUDA devices are seeded too.
    """
    import random
    import numpy as np

    for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
        seed_fn(seed)
    if not torch.cuda.is_available():
        return
    torch.cuda.manual_seed_all(seed)

def train_epoch(model, loss_fn, dataloader, optimizer, device, grad_clip, rollout_steps):
    """Run one training epoch and return the batch-averaged loss terms.

    For each batch, the tensors under ``'context'``, ``'target'``,
    ``'weights'``, ``'signals'`` and ``'hedge'`` are moved to ``device`` and
    passed to ``model``. The ``'loss'`` entry returned by ``loss_fn`` is
    backpropagated. Gradients are norm-clipped when ``grad_clip`` is positive,
    the optimizer steps, and ``model.update_target()`` is called after every
    step.

    Args:
        model: Module called as ``model(ctx, tgt, w, s, h)`` that also exposes
            ``update_target()``.
        loss_fn: Called as ``loss_fn(outputs, actions_gt, rollout_outputs)``.
            Must return a dict containing the keys listed under Returns.
        dataloader: Iterable of batch dicts.
        optimizer: Optimizer over ``model``'s parameters.
        device: Device the batch tensors are moved to.
        grad_clip: Maximum gradient norm. A value ``<= 0`` disables clipping.
        rollout_steps: ``rollout_steps - 1`` extra forward passes are collected
            and passed to ``loss_fn`` as ``rollout_outputs``.

    Returns:
        dict mapping ``'loss'``, ``'loss_pred'``, ``'loss_reg'``,
        ``'loss_temporal'``, ``'loss_idm'`` and ``'loss_rollout'`` to their
        per-batch means as Python floats. Every value is NaN when
        ``dataloader`` yields no batches, instead of raising
        ``ZeroDivisionError``.
    """
    metric_keys = ('loss', 'loss_pred', 'loss_reg', 'loss_temporal', 'loss_idm', 'loss_rollout')
    totals = dict.fromkeys(metric_keys, 0.0)
    n = 0
    model.train()
    for batch in dataloader:
        ctx = batch['context'].to(device)
        tgt = batch['target'].to(device)
        w = batch['weights'].to(device)
        s = batch['signals'].to(device)
        h = batch['hedge'].to(device)
        outputs = model(ctx, tgt, w, s, h)
        actions_gt = {'weights': w, 'signals': s}
        # NOTE(review): every "rollout" pass reuses exactly the same inputs as
        # `outputs`. Any difference from `outputs` can only come from
        # randomness inside the model, if it has any. Confirm this is intended
        # rather than feeding each step's prediction back in.
        rollout_outputs = [model(ctx, tgt, w, s, h) for _ in range(1, rollout_steps)]
        loss_dict = loss_fn(outputs, actions_gt, rollout_outputs)
        loss = loss_dict['loss']
        optimizer.zero_grad()
        loss.backward()
        if grad_clip > 0:
            nn.utils.clip_grad_norm_(model.parameters(), grad_clip)
        optimizer.step()
        model.update_target()
        # float() accepts Python numbers and one-element tensors alike. For
        # tensors, it also stops the running sum from holding on to every
        # batch's autograd graph for the rest of the epoch.
        for key in metric_keys:
            totals[key] += float(loss_dict[key])
        n += 1
    if n == 0:
        return dict.fromkeys(metric_keys, float('nan'))
    return {key: totals[key] / n for key in metric_keys}

@torch.no_grad()
def evaluate(model, loss_fn, dataloader, device, rollout_steps):
    """Compute batch-averaged losses over ``dataloader`` with gradients disabled.

    Puts ``model`` in eval mode. Each batch goes through the same forward and
    loss path as training, including ``rollout_steps - 1`` extra forward
    passes handed to ``loss_fn`` as ``rollout_outputs``. No optimizer step or
    ``update_target()`` call is made.

    Args:
        model: Module called as ``model(ctx, tgt, w, s, h)``.
        loss_fn: Called as ``loss_fn(outputs, actions_gt, rollout_outputs)``.
            Must return a dict with ``'loss'`` and ``'loss_pred'``.
        dataloader: Iterable of batch dicts with ``'context'``,
            ``'target'``, ``'weights'``, ``'signals'`` and ``'hedge'``.
        device: Device the batch tensors are moved to.
        rollout_steps: Number of forward passes per batch, including the
            main one.

    Returns:
        dict with ``'loss'`` and ``'loss_pred'`` as Python floats (per-batch
        means). Both are NaN when ``dataloader`` yields no batches, instead of
        raising ``ZeroDivisionError``.
    """
    metric_keys = ('loss', 'loss_pred')
    totals = dict.fromkeys(metric_keys, 0.0)
    n = 0
    model.eval()
    for batch in dataloader:
        ctx = batch['context'].to(device)
        tgt = batch['target'].to(device)
        w = batch['weights'].to(device)
        s = batch['signals'].to(device)
        h = batch['hedge'].to(device)
        outputs = model(ctx, tgt, w, s, h)
        actions_gt = {'weights': w, 'signals': s}
        # NOTE(review): rollout passes reuse the same inputs as `outputs`.
        # Confirm this matches the intended multi-step rollout.
        rollout_outputs = [model(ctx, tgt, w, s, h) for _ in range(1, rollout_steps)]
        loss_dict = loss_fn(outputs, actions_gt, rollout_outputs)
        # float() handles both Python numbers and one-element tensors, so the
        # returned metrics are always plain floats.
        for key in metric_keys:
            totals[key] += float(loss_dict[key])
        n += 1
    if n == 0:
        return dict.fromkeys(metric_keys, float('nan'))
    return {key: totals[key] / n for key in metric_keys}

def _load_dataframe(args):
    """Return the training data selected by ``args.data_source``.

    ``'synthetic'`` generates 10,000 timesteps for ``args.n_assets`` assets.
    Otherwise ``args.dataset_name`` is loaded through ``load_hf_stock_data``,
    optionally restricted to the comma-separated ``args.symbols``.
    """
    if args.data_source == 'synthetic':
        return generate_synthetic_data(n_timesteps=10000, n_assets=args.n_assets)
    # Strip whitespace so "AAPL, MSFT" works, and drop empty entries such as
    # a trailing comma. A list that ends up empty falls back to None.
    symbols = None
    if args.symbols:
        symbols = [sym.strip() for sym in args.symbols.split(',') if sym.strip()] or None
    return load_hf_stock_data(args.dataset_name, symbols=symbols)


def _push_to_hub(args):
    """Upload everything in ``args.output_dir`` to the Hub model repo ``args.hub_model_id``."""
    # Imported lazily: huggingface_hub is only needed when --push_to_hub is set.
    from huggingface_hub import HfApi
    api = HfApi()
    api.create_repo(args.hub_model_id, repo_type="model", exist_ok=True)
    api.upload_folder(folder_path=args.output_dir, repo_id=args.hub_model_id, repo_type="model")
    print(f"Pushed to https://huggingface.co/{args.hub_model_id}")


def main():
    """Script entry point: train FinJEPA and save checkpoints.

    Writes ``config.json`` (the parsed arguments), ``best_model.pt`` (state
    dict at the lowest validation loss seen so far) and ``final_model.pt``
    into ``--output_dir``. Uploads that directory when ``--push_to_hub`` is
    given.
    """
    args = get_args()
    set_seed(args.seed)
    os.makedirs(args.output_dir, exist_ok=True)
    with open(os.path.join(args.output_dir, 'config.json'), 'w', encoding='utf-8') as f:
        json.dump(vars(args), f, indent=2)

    df = _load_dataframe(args)
    print(f"Data shape: {df.shape}")
    loaders = build_dataloaders(df, n_assets=args.n_assets, context_window=args.context_window,
                                target_window=args.target_window, batch_size=args.batch_size)

    # NOTE(review): hard-coded input feature count. It must match the
    # per-timestep features produced by build_dataloaders — confirm in
    # finjepa.data.
    n_features = 14
    model = FinJEPA(in_features=n_features, n_assets=args.n_assets, patch_size=args.patch_size,
                    embed_dim=args.embed_dim, encoder_depth=args.encoder_depth, encoder_heads=args.encoder_heads,
                    predictor_depth=args.predictor_depth, predictor_heads=args.predictor_heads,
                    dropout=args.dropout, ema_decay=args.ema_decay, use_idm=args.use_idm).to(args.device)
    loss_fn = FinJEPALoss(pred_loss='l1', alpha=args.alpha, beta=args.beta, delta=args.delta,
                          omega=args.omega, gamma=args.gamma).to(args.device)
    optimizer = AdamW(model.parameters(), lr=args.lr, betas=(0.9, 0.995), weight_decay=args.weight_decay)
    print(f"Trainable params: {sum(p.numel() for p in model.parameters() if p.requires_grad):,}")

    best_val = float('inf')
    best_path = os.path.join(args.output_dir, 'best_model.pt')
    for epoch in range(args.epochs):
        train_metrics = train_epoch(model, loss_fn, loaders['train'], optimizer, args.device, args.grad_clip, args.rollout_steps)
        val_metrics = evaluate(model, loss_fn, loaders['val'], args.device, args.rollout_steps)
        print(f"Epoch {epoch+1}/{args.epochs} | train_loss={train_metrics['loss']:.4f} "
              f"(pred={train_metrics['loss_pred']:.4f} reg={train_metrics['loss_reg']:.4f} "
              f"temp={train_metrics['loss_temporal']:.4f} idm={train_metrics['loss_idm']:.4f} "
              f"rollout={train_metrics['loss_rollout']:.4f}) | val_loss={val_metrics['loss']:.4f}")
        # A NaN validation loss never compares lower, so it never replaces the
        # best checkpoint.
        if val_metrics['loss'] < best_val:
            best_val = val_metrics['loss']
            torch.save(model.state_dict(), best_path)
    torch.save(model.state_dict(), os.path.join(args.output_dir, 'final_model.pt'))

    if args.push_to_hub:
        _push_to_hub(args)


if __name__ == '__main__':
    main()