ashesh8500 commited on
Commit
09dd12e
·
verified ·
1 Parent(s): 9ed614d

Upload train.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. train.py +161 -0
train.py ADDED
@@ -0,0 +1,161 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os, argparse, json, math, sys
2
+ sys.path.insert(0, '/app')
3
+ import torch
4
+ import torch.nn as nn
5
+ from torch.utils.data import DataLoader
6
+ from torch.optim import AdamW
7
+ from finjepa.model import FinJEPA, FinJEPALoss
8
+ from finjepa.data import FinancialTrajectoryDataset, build_dataloaders, generate_synthetic_data, load_hf_stock_data
9
+
10
+ def get_args():
11
+ parser = argparse.ArgumentParser()
12
+ parser.add_argument('--data_source', type=str, default='synthetic', choices=['synthetic', 'hf'])
13
+ parser.add_argument('--dataset_name', type=str, default='paperswithbacktest/Stocks-Daily-Price')
14
+ parser.add_argument('--symbols', type=str, default=None)
15
+ parser.add_argument('--n_assets', type=int, default=1)
16
+ parser.add_argument('--context_window', type=int, default=60)
17
+ parser.add_argument('--target_window', type=int, default=5)
18
+ parser.add_argument('--batch_size', type=int, default=128)
19
+ parser.add_argument('--embed_dim', type=int, default=128)
20
+ parser.add_argument('--encoder_depth', type=int, default=4)
21
+ parser.add_argument('--encoder_heads', type=int, default=4)
22
+ parser.add_argument('--predictor_depth', type=int, default=6)
23
+ parser.add_argument('--predictor_heads', type=int, default=4)
24
+ parser.add_argument('--patch_size', type=int, default=4)
25
+ parser.add_argument('--dropout', type=float, default=0.0)
26
+ parser.add_argument('--ema_decay', type=float, default=0.996)
27
+ parser.add_argument('--use_idm', action='store_true', default=True)
28
+ parser.add_argument('--lr', type=float, default=0.001)
29
+ parser.add_argument('--weight_decay', type=float, default=1e-6)
30
+ parser.add_argument('--epochs', type=int, default=50)
31
+ parser.add_argument('--grad_clip', type=float, default=1.0)
32
+ parser.add_argument('--rollout_steps', type=int, default=2)
33
+ parser.add_argument('--alpha', type=float, default=2.0)
34
+ parser.add_argument('--beta', type=float, default=1.0)
35
+ parser.add_argument('--delta', type=float, default=4.0)
36
+ parser.add_argument('--omega', type=float, default=0.5)
37
+ parser.add_argument('--gamma', type=float, default=0.5)
38
+ parser.add_argument('--output_dir', type=str, default='./outputs')
39
+ parser.add_argument('--device', type=str, default='cuda' if torch.cuda.is_available() else 'cpu')
40
+ parser.add_argument('--seed', type=int, default=42)
41
+ parser.add_argument('--push_to_hub', action='store_true')
42
+ parser.add_argument('--hub_model_id', type=str, default='ashesh8500/finjepa')
43
+ return parser.parse_args()
44
+
45
+ def set_seed(seed):
46
+ import random, numpy as np
47
+ random.seed(seed)
48
+ np.random.seed(seed)
49
+ torch.manual_seed(seed)
50
+ if torch.cuda.is_available():
51
+ torch.cuda.manual_seed_all(seed)
52
+
53
+ def train_epoch(model, loss_fn, dataloader, optimizer, device, grad_clip, rollout_steps):
54
+ model.train()
55
+ total_loss = 0
56
+ total_pred = 0
57
+ total_reg = 0
58
+ total_temporal = 0
59
+ total_idm = 0
60
+ total_rollout = 0
61
+ n = 0
62
+ for batch in dataloader:
63
+ ctx = batch['context'].to(device)
64
+ tgt = batch['target'].to(device)
65
+ w = batch['weights'].to(device)
66
+ s = batch['signals'].to(device)
67
+ h = batch['hedge'].to(device)
68
+ outputs = model(ctx, tgt, w, s, h)
69
+ actions_gt = {'weights': w, 'signals': s}
70
+ rollout_outputs = []
71
+ if rollout_steps > 1:
72
+ for k in range(1, rollout_steps):
73
+ ro = model(ctx, tgt, w, s, h)
74
+ rollout_outputs.append(ro)
75
+ loss_dict = loss_fn(outputs, actions_gt, rollout_outputs)
76
+ loss = loss_dict['loss']
77
+ optimizer.zero_grad()
78
+ loss.backward()
79
+ if grad_clip > 0:
80
+ nn.utils.clip_grad_norm_(model.parameters(), grad_clip)
81
+ optimizer.step()
82
+ model.update_target()
83
+ total_loss += loss.item()
84
+ total_pred += loss_dict['loss_pred']
85
+ total_reg += loss_dict['loss_reg']
86
+ total_temporal += loss_dict['loss_temporal']
87
+ total_idm += loss_dict['loss_idm']
88
+ total_rollout += loss_dict['loss_rollout']
89
+ n += 1
90
+ return {'loss': total_loss / n, 'loss_pred': total_pred / n, 'loss_reg': total_reg / n,
91
+ 'loss_temporal': total_temporal / n, 'loss_idm': total_idm / n, 'loss_rollout': total_rollout / n}
92
+
93
+ @torch.no_grad()
94
+ def evaluate(model, loss_fn, dataloader, device, rollout_steps):
95
+ model.eval()
96
+ total_loss = 0
97
+ total_pred = 0
98
+ n = 0
99
+ for batch in dataloader:
100
+ ctx = batch['context'].to(device)
101
+ tgt = batch['target'].to(device)
102
+ w = batch['weights'].to(device)
103
+ s = batch['signals'].to(device)
104
+ h = batch['hedge'].to(device)
105
+ outputs = model(ctx, tgt, w, s, h)
106
+ actions_gt = {'weights': w, 'signals': s}
107
+ rollout_outputs = []
108
+ if rollout_steps > 1:
109
+ for k in range(1, rollout_steps):
110
+ ro = model(ctx, tgt, w, s, h)
111
+ rollout_outputs.append(ro)
112
+ loss_dict = loss_fn(outputs, actions_gt, rollout_outputs)
113
+ total_loss += loss_dict['loss'].item()
114
+ total_pred += loss_dict['loss_pred']
115
+ n += 1
116
+ return {'loss': total_loss / n, 'loss_pred': total_pred / n}
117
+
118
+ def main():
119
+ args = get_args()
120
+ set_seed(args.seed)
121
+ os.makedirs(args.output_dir, exist_ok=True)
122
+ with open(os.path.join(args.output_dir, 'config.json'), 'w') as f:
123
+ json.dump(vars(args), f, indent=2)
124
+ if args.data_source == 'synthetic':
125
+ df = generate_synthetic_data(n_timesteps=10000, n_assets=args.n_assets)
126
+ else:
127
+ symbols = args.symbols.split(',') if args.symbols else None
128
+ df = load_hf_stock_data(args.dataset_name, symbols=symbols)
129
+ print(f"Data shape: {df.shape}")
130
+ loaders = build_dataloaders(df, n_assets=args.n_assets, context_window=args.context_window,
131
+ target_window=args.target_window, batch_size=args.batch_size)
132
+ n_features = 14
133
+ model = FinJEPA(in_features=n_features, n_assets=args.n_assets, patch_size=args.patch_size,
134
+ embed_dim=args.embed_dim, encoder_depth=args.encoder_depth, encoder_heads=args.encoder_heads,
135
+ predictor_depth=args.predictor_depth, predictor_heads=args.predictor_heads,
136
+ dropout=args.dropout, ema_decay=args.ema_decay, use_idm=args.use_idm).to(args.device)
137
+ loss_fn = FinJEPALoss(pred_loss='l1', alpha=args.alpha, beta=args.beta, delta=args.delta,
138
+ omega=args.omega, gamma=args.gamma).to(args.device)
139
+ optimizer = AdamW(model.parameters(), lr=args.lr, betas=(0.9, 0.995), weight_decay=args.weight_decay)
140
+ print(f"Trainable params: {sum(p.numel() for p in model.parameters() if p.requires_grad):,}")
141
+ best_val = float('inf')
142
+ for epoch in range(args.epochs):
143
+ train_metrics = train_epoch(model, loss_fn, loaders['train'], optimizer, args.device, args.grad_clip, args.rollout_steps)
144
+ val_metrics = evaluate(model, loss_fn, loaders['val'], args.device, args.rollout_steps)
145
+ print(f"Epoch {epoch+1}/{args.epochs} | train_loss={train_metrics['loss']:.4f} "
146
+ f"(pred={train_metrics['loss_pred']:.4f} reg={train_metrics['loss_reg']:.4f} "
147
+ f"temp={train_metrics['loss_temporal']:.4f} idm={train_metrics['loss_idm']:.4f} "
148
+ f"rollout={train_metrics['loss_rollout']:.4f}) | val_loss={val_metrics['loss']:.4f}")
149
+ if val_metrics['loss'] < best_val:
150
+ best_val = val_metrics['loss']
151
+ torch.save(model.state_dict(), os.path.join(args.output_dir, 'best_model.pt'))
152
+ torch.save(model.state_dict(), os.path.join(args.output_dir, 'final_model.pt'))
153
+ if args.push_to_hub:
154
+ from huggingface_hub import HfApi
155
+ api = HfApi()
156
+ api.create_repo(args.hub_model_id, repo_type="model", exist_ok=True)
157
+ api.upload_folder(folder_path=args.output_dir, repo_id=args.hub_model_id, repo_type="model")
158
+ print(f"Pushed to https://huggingface.co/{args.hub_model_id}")
159
+
160
+ if __name__ == '__main__':
161
+ main()