Upload run_training_fast.py with huggingface_hub
Browse files — run_training_fast.py (+59, -0)
run_training_fast.py
ADDED
|
@@ -0,0 +1,59 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
"""Fast smoke-test training run for FinJEPA on synthetic single-asset data.

Trains a deliberately small FinJEPA model for a few epochs on CPU,
sanity-checks the CEM planner on one test context window, and saves the
resulting weights to /app/finjepa/outputs/fast_model.pt.
"""
import os
import sys

sys.path.insert(0, '/app')  # make the /app package root importable

import torch
from torch.optim import AdamW

from finjepa.model import FinJEPA, FinJEPALoss
from finjepa.data import generate_synthetic_data, build_dataloaders
from finjepa.planner import CEMPlanner

DEVICE = 'cpu'
EPOCHS = 3  # single source of truth for the epoch count (was hard-coded twice)

# Synthetic data and loaders; batch keys below ('context', 'target', 'weights',
# 'signals', 'hedge') are assumed from finjepa.data — confirm against that module.
df = generate_synthetic_data(n_timesteps=5000, n_assets=1)
loaders = build_dataloaders(df, n_assets=1, context_window=60, target_window=10, batch_size=64)

# Small configuration chosen for speed (CPU smoke test), not model quality.
model = FinJEPA(
    in_features=14, n_assets=1, patch_size=4, embed_dim=64,
    encoder_depth=2, encoder_heads=4, predictor_depth=3, predictor_heads=4,
    ema_decay=0.996, use_idm=True,
).to(DEVICE)

loss_fn = FinJEPALoss(pred_loss='l1', alpha=2.0, beta=1.0, delta=4.0, omega=0.5, gamma=0.5).to(DEVICE)
optimizer = AdamW(model.parameters(), lr=0.001, betas=(0.9, 0.995), weight_decay=1e-5)

print(f"Trainable params: {sum(p.numel() for p in model.parameters() if p.requires_grad):,}")

for epoch in range(EPOCHS):
    model.train()
    total_loss = 0.0
    n_batches = 0
    for batch in loaders['train']:
        ctx = batch['context'].to(DEVICE)
        tgt = batch['target'].to(DEVICE)
        w = batch['weights'].to(DEVICE)
        s = batch['signals'].to(DEVICE)
        h = batch['hedge'].to(DEVICE)

        out = model(ctx, tgt, w, s, h)
        actions_gt = {'weights': w, 'signals': s}
        loss_dict = loss_fn(out, actions_gt)
        loss = loss_dict['loss']

        optimizer.zero_grad()
        loss.backward()
        # Clip gradient norm to 1.0 to keep early training stable.
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()
        # Update the target network after the optimizer step — presumably the
        # EMA update implied by ema_decay above; confirm in finjepa.model.
        model.update_target()

        total_loss += loss.item()
        n_batches += 1
    # Guard against an empty train loader (original divided by zero here).
    avg_loss = total_loss / n_batches if n_batches else float('nan')
    print(f"Epoch {epoch+1}/{EPOCHS} | loss={avg_loss:.3f}")

print("Training complete. Testing planner...")
planner = CEMPlanner(model, n_assets=1, horizon=5, n_candidates=20, n_elites=5, n_iterations=2)
batch = next(iter(loaders['test']))
ctx = batch['context'][0:1].to(DEVICE)  # single-sample batch for planning
result = planner.plan(ctx)
print(f"Best weights: {result['weights'].detach().cpu().numpy().round(3)}")
print(f"Best signals: {result['signals'].detach().cpu().numpy()}")
print(f"Expected cost: {result['expected_cost'].item():.4f}")

os.makedirs('/app/finjepa/outputs', exist_ok=True)
torch.save(model.state_dict(), '/app/finjepa/outputs/fast_model.pt')
print("Model saved.")