File size: 2,254 Bytes
e474e94
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
import os
import sys

# Make the in-repo `finjepa` package importable; must run before the finjepa imports below.
sys.path.insert(0, '/app')

import torch
from torch.optim import AdamW

from finjepa.data import generate_synthetic_data, build_dataloaders
from finjepa.model import FinJEPA, FinJEPALoss
from finjepa.planner import CEMPlanner

# Fast end-to-end smoke run: train a small FinJEPA on synthetic single-asset
# data for a few epochs, exercise the CEM planner on one test context, then
# save the trained weights.
import os

DEVICE = 'cpu'
N_EPOCHS = 3
OUTPUT_DIR = '/app/finjepa/outputs'

df = generate_synthetic_data(n_timesteps=5000, n_assets=1)
loaders = build_dataloaders(df, n_assets=1, context_window=60, target_window=10, batch_size=64)

model = FinJEPA(
    in_features=14, n_assets=1, patch_size=4, embed_dim=64,
    encoder_depth=2, encoder_heads=4, predictor_depth=3, predictor_heads=4,
    ema_decay=0.996, use_idm=True,
).to(DEVICE)

loss_fn = FinJEPALoss(pred_loss='l1', alpha=2.0, beta=1.0, delta=4.0, omega=0.5, gamma=0.5).to(DEVICE)
optimizer = AdamW(model.parameters(), lr=0.001, betas=(0.9, 0.995), weight_decay=1e-5)

print(f"Trainable params: {sum(p.numel() for p in model.parameters() if p.requires_grad):,}")

for epoch in range(N_EPOCHS):
    model.train()
    total_loss = 0.0
    n = 0
    for batch in loaders['train']:
        ctx = batch['context'].to(DEVICE)
        tgt = batch['target'].to(DEVICE)
        w = batch['weights'].to(DEVICE)
        s = batch['signals'].to(DEVICE)
        h = batch['hedge'].to(DEVICE)
        out = model(ctx, tgt, w, s, h)
        # NOTE(review): `hedge` is fed to the model but is not part of the
        # ground-truth actions given to the loss — confirm this is intended.
        actions_gt = {'weights': w, 'signals': s}
        loss_dict = loss_fn(out, actions_gt)
        loss = loss_dict['loss']
        optimizer.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()
        # Update the target network after every optimizer step
        # (FinJEPA is constructed with ema_decay above).
        model.update_target()
        total_loss += loss.item()
        n += 1
    if n == 0:
        # Fail with a clear message instead of a ZeroDivisionError below.
        raise RuntimeError("Training loader yielded no batches; check dataset size and window settings.")
    print(f"Epoch {epoch + 1}/{N_EPOCHS} | loss={total_loss / n:.3f}")

print("Training complete. Testing planner...")
# The training loop leaves the model in train mode; planning must use
# inference behaviour (e.g. dropout disabled) to be meaningful and repeatable.
model.eval()
planner = CEMPlanner(model, n_assets=1, horizon=5, n_candidates=20, n_elites=5, n_iterations=2)
batch = next(iter(loaders['test']))
ctx = batch['context'][0:1].to(DEVICE)  # keep a batch dimension of 1
# CEM is a sampling-based (gradient-free) optimizer, so no autograd graph is needed.
with torch.no_grad():
    result = planner.plan(ctx)
print(f"Best weights: {result['weights'].detach().cpu().numpy().round(3)}")
print(f"Best signals: {result['signals'].detach().cpu().numpy()}")
print(f"Expected cost: {result['expected_cost'].item():.4f}")

os.makedirs(OUTPUT_DIR, exist_ok=True)
torch.save(model.state_dict(), os.path.join(OUTPUT_DIR, 'fast_model.pt'))
print("Model saved.")