| import sys |
| sys.path.insert(0, '/app') |
| import torch |
| from finjepa.model import FinJEPA, FinJEPALoss |
| from finjepa.data import generate_synthetic_data, build_dataloaders |
| from torch.optim import AdamW |
| from finjepa.planner import CEMPlanner |
|
|
# Every module and tensor in this script is placed on the CPU.
DEVICE = 'cpu'

# Synthetic market data for a single asset, 5000 timesteps long.
df = generate_synthetic_data(n_timesteps=5000, n_assets=1)

# Windowing / batching settings forwarded to build_dataloaders. The returned
# mapping is indexed by split name further down ('train' and 'test').
loader_cfg = {
    'n_assets': 1,
    'context_window': 60,
    'target_window': 10,
    'batch_size': 64,
}
loaders = build_dataloaders(df, **loader_cfg)
|
|
# FinJEPA hyper-parameters, gathered in one mapping so they are easy to
# inspect or log.
# NOTE(review): in_features=14 must match the per-step feature count produced
# by build_dataloaders — TODO confirm.
model_cfg = dict(
    in_features=14,
    n_assets=1,
    patch_size=4,
    embed_dim=64,
    encoder_depth=2,
    encoder_heads=4,
    predictor_depth=3,
    predictor_heads=4,
    ema_decay=0.996,
    use_idm=True,
)
model = FinJEPA(**model_cfg)
model = model.to(DEVICE)
|
|
# Training objective. pred_loss='l1' selects the prediction-loss type; the
# remaining coefficients are presumably per-term weights — see FinJEPALoss.
loss_fn = FinJEPALoss(
    pred_loss='l1',
    alpha=2.0,
    beta=1.0,
    delta=4.0,
    omega=0.5,
    gamma=0.5,
)
loss_fn = loss_fn.to(DEVICE)

# AdamW over all model parameters (decoupled weight decay).
adamw_hparams = {
    'lr': 1e-3,
    'betas': (0.9, 0.995),
    'weight_decay': 1e-5,
}
optimizer = AdamW(model.parameters(), **adamw_hparams)
|
|
# Report how many parameters the optimizer can actually update
# (those with requires_grad=True).
trainable = [param for param in model.parameters() if param.requires_grad]
n_trainable = sum(param.numel() for param in trainable)
print(f"Trainable params: {n_trainable:,}")
|
|
# Number of passes over the training split. It is also used in the progress
# line, so the loop bound and the printed total cannot drift apart.
N_EPOCHS = 3

for epoch in range(N_EPOCHS):
    model.train()
    total_loss = 0.0
    n = 0
    for batch in loaders['train']:
        # Move each field the model / loss needs onto the training device.
        ctx = batch['context'].to(DEVICE)
        tgt = batch['target'].to(DEVICE)
        w = batch['weights'].to(DEVICE)
        s = batch['signals'].to(DEVICE)
        h = batch['hedge'].to(DEVICE)

        out = model(ctx, tgt, w, s, h)
        # Ground-truth actions used by the loss (weights and signals only;
        # hedge is passed to the model but not to the loss).
        actions_gt = {'weights': w, 'signals': s}
        loss_dict = loss_fn(out, actions_gt)
        loss = loss_dict['loss']

        optimizer.zero_grad()
        loss.backward()
        # Rescale gradients so their global L2 norm is at most 1.0.
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()
        # Presumably the EMA update of the target network (the model was built
        # with ema_decay) — must run after optimizer.step(); confirm in FinJEPA.
        model.update_target()

        total_loss += loss.item()
        n += 1

    # An empty training loader would otherwise surface as an opaque
    # ZeroDivisionError when computing the mean loss below.
    if n == 0:
        raise RuntimeError("loaders['train'] yielded no batches; cannot compute epoch loss")
    print(f"Epoch {epoch+1}/{N_EPOCHS} | loss={total_loss / n:.3f}")
|
|
print("Training complete. Testing planner...")

# The training loop leaves the model in train mode; switch to eval so any
# train-only layer behaviour (e.g. dropout, if FinJEPA uses it) is off while
# the planner queries the model.
model.eval()

planner = CEMPlanner(model, n_assets=1, horizon=5, n_candidates=20, n_elites=5, n_iterations=2)

try:
    batch = next(iter(loaders['test']))
except StopIteration:
    raise RuntimeError("loaders['test'] yielded no batches; cannot test planner") from None

# Plan from the first sample only; slicing 0:1 (rather than indexing 0) keeps
# the leading dimension — assumed to be the batch dimension.
ctx = batch['context'][0:1].to(DEVICE)

# CEM is sampling-based, so no autograd graph is needed while planning.
# NOTE(review): assumes CEMPlanner.plan never calls backward() — confirm.
with torch.no_grad():
    result = planner.plan(ctx)

print(f"Best weights: {result['weights'].detach().cpu().numpy().round(3)}")
print(f"Best signals: {result['signals'].detach().cpu().numpy()}")
print(f"Expected cost: {result['expected_cost'].item():.4f}")
|
|
import os

# Persist the trained model weights (state_dict only — optimizer state is not
# saved), creating the output directory if it does not exist yet.
OUTPUT_DIR = '/app/finjepa/outputs'
os.makedirs(OUTPUT_DIR, exist_ok=True)

checkpoint_path = f'{OUTPUT_DIR}/fast_model.pt'
torch.save(model.state_dict(), checkpoint_path)
print("Model saved.")
|
|