ashesh8500 commited on
Commit
e474e94
·
verified ·
1 Parent(s): 09dd12e

Upload run_training_fast.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. run_training_fast.py +59 -0
run_training_fast.py ADDED
@@ -0,0 +1,59 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import sys
2
+ sys.path.insert(0, '/app')
3
+ import torch
4
+ from finjepa.model import FinJEPA, FinJEPALoss
5
+ from finjepa.data import generate_synthetic_data, build_dataloaders
6
+ from torch.optim import AdamW
7
+ from finjepa.planner import CEMPlanner
8
+
9
+ DEVICE = 'cpu'
10
+ df = generate_synthetic_data(n_timesteps=5000, n_assets=1)
11
+ loaders = build_dataloaders(df, n_assets=1, context_window=60, target_window=10, batch_size=64)
12
+
13
+ model = FinJEPA(
14
+ in_features=14, n_assets=1, patch_size=4, embed_dim=64,
15
+ encoder_depth=2, encoder_heads=4, predictor_depth=3, predictor_heads=4,
16
+ ema_decay=0.996, use_idm=True,
17
+ ).to(DEVICE)
18
+
19
+ loss_fn = FinJEPALoss(pred_loss='l1', alpha=2.0, beta=1.0, delta=4.0, omega=0.5, gamma=0.5).to(DEVICE)
20
+ optimizer = AdamW(model.parameters(), lr=0.001, betas=(0.9, 0.995), weight_decay=1e-5)
21
+
22
+ print(f"Trainable params: {sum(p.numel() for p in model.parameters() if p.requires_grad):,}")
23
+
24
+ for epoch in range(3):
25
+ model.train()
26
+ total_loss = 0
27
+ n = 0
28
+ for batch in loaders['train']:
29
+ ctx = batch['context'].to(DEVICE)
30
+ tgt = batch['target'].to(DEVICE)
31
+ w = batch['weights'].to(DEVICE)
32
+ s = batch['signals'].to(DEVICE)
33
+ h = batch['hedge'].to(DEVICE)
34
+ out = model(ctx, tgt, w, s, h)
35
+ actions_gt = {'weights': w, 'signals': s}
36
+ loss_dict = loss_fn(out, actions_gt)
37
+ loss = loss_dict['loss']
38
+ optimizer.zero_grad()
39
+ loss.backward()
40
+ torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
41
+ optimizer.step()
42
+ model.update_target()
43
+ total_loss += loss.item()
44
+ n += 1
45
+ print(f"Epoch {epoch+1}/3 | loss={total_loss / n:.3f}")
46
+
47
+ print("Training complete. Testing planner...")
48
+ planner = CEMPlanner(model, n_assets=1, horizon=5, n_candidates=20, n_elites=5, n_iterations=2)
49
+ batch = next(iter(loaders['test']))
50
+ ctx = batch['context'][0:1].to(DEVICE)
51
+ result = planner.plan(ctx)
52
+ print(f"Best weights: {result['weights'].detach().cpu().numpy().round(3)}")
53
+ print(f"Best signals: {result['signals'].detach().cpu().numpy()}")
54
+ print(f"Expected cost: {result['expected_cost'].item():.4f}")
55
+
56
+ import os
57
+ os.makedirs('/app/finjepa/outputs', exist_ok=True)
58
+ torch.save(model.state_dict(), '/app/finjepa/outputs/fast_model.pt')
59
+ print("Model saved.")