import torch
import torch.nn.functional as F
import numpy as np
from typing import Optional, Callable
class MPPIPlanner:
    """Model Predictive Path Integral (MPPI) planner over portfolio weights.

    Each call to :meth:`plan` samples weight trajectories around a running
    mean, rolls them through ``model.predict_next_state``, and refits the
    sampling distribution with exponentially weighted averaging of the
    lowest-cost samples.
    """

    def __init__(self, model, n_assets: int, horizon=20, n_samples=200, n_iterations=20,
                 temperature=0.005, action_std=0.2, signal_prior=0.33):
        # model must expose predict_next_state(state, weights, signals) -> tensor
        # whose dims 1 and 2 can be mean-reduced for the default cost.
        self.model = model
        self.n_assets = n_assets
        self.horizon = horizon
        self.n_samples = n_samples
        self.n_iterations = n_iterations
        self.temperature = temperature       # softmin temperature for MPPI weighting
        self.action_std = action_std         # initial per-asset sampling std
        self.signal_prior = signal_prior     # NOTE(review): currently unused; kept for interface stability

    def plan(self, state_series, cost_fn=None, initial_weights=None):
        """Plan portfolio weights/signals for the first step of the horizon.

        Args:
            state_series: state tensor; assumed (1, T, D) — TODO confirm with caller.
            cost_fn: optional callable mapping predicted state -> per-sample cost
                of shape (n_samples,); defaults to negative mean of the prediction.
            initial_weights: optional (n_assets,) tensor used to seed the mean.

        Returns:
            dict with "weights" (B, n_assets), "signals" (B, n_assets) for t=0,
            and "expected_cost" (B,), the best sampled trajectory cost.
        """
        device = state_series.device
        B = 1  # single planning scenario; state replication below assumes this
        if initial_weights is None:
            mean_weights = torch.ones(B, self.horizon, self.n_assets, device=device) / self.n_assets
        else:
            mean_weights = initial_weights.unsqueeze(0).expand(B, self.horizon, self.n_assets)
        std_weights = torch.ones_like(mean_weights) * self.action_std
        for it in range(self.n_iterations):
            noise = torch.randn(B * self.n_samples, self.horizon, self.n_assets, device=device)
            # BUG FIX: the original broadcast mean.unsqueeze(1) against noise,
            # yielding a (B, S, H, A) tensor whose dim 1 was the SAMPLE axis;
            # sampled_weights[:, t] then indexed samples with the time index.
            # Replicate mean/std per sample so everything stays flat (B*S, H, A).
            mean_rep = mean_weights.repeat_interleave(self.n_samples, dim=0)
            std_rep = std_weights.repeat_interleave(self.n_samples, dim=0)
            sampled_weights = F.softmax(mean_rep + std_rep * noise, dim=-1)
            # Discrete signals are resampled uniformly from {0,1,2} each iteration
            # (no distribution refinement for the discrete part).
            sampled_signals = torch.randint(0, 3, (B * self.n_samples, self.horizon, self.n_assets), device=device)
            costs = torch.zeros(B * self.n_samples, device=device)
            current_state = state_series.repeat(self.n_samples, 1, 1)
            for t in range(self.horizon):
                w_t = sampled_weights[:, t]   # (B*S, A)
                s_t = sampled_signals[:, t]   # (B*S, A)
                # NOTE(review): current_state is never advanced with z_next, so
                # every step is predicted from the original series — confirm
                # this matches the model's sequence-conditioned contract.
                z_next = self.model.predict_next_state(current_state, w_t, s_t)
                if cost_fn is not None:
                    c_t = cost_fn(z_next)
                else:
                    c_t = -z_next.mean(dim=(1, 2))
                costs += c_t
            costs = costs.reshape(B, self.n_samples)
            # Standard MPPI update: softmin-weighted average, baseline-shifted
            # by the per-batch minimum for numerical stability.
            beta = costs.min(dim=1, keepdim=True).values
            weights_mppi = torch.exp(-(costs - beta) / self.temperature)
            weights_mppi = weights_mppi / weights_mppi.sum(dim=1, keepdim=True)
            flat_weights = sampled_weights.reshape(B, self.n_samples, self.horizon, self.n_assets)
            mean_weights = (weights_mppi[:, :, None, None] * flat_weights).sum(dim=1)
            # Small floor keeps the std strictly positive for the next iteration.
            std_weights = torch.sqrt((weights_mppi[:, :, None, None] * (flat_weights - mean_weights.unsqueeze(1)) ** 2).sum(dim=1) + 1e-4)
        best_idx = costs.argmin(dim=1)
        best_weights = sampled_weights.reshape(B, self.n_samples, self.horizon, self.n_assets)[torch.arange(B), best_idx, 0]
        best_signals = sampled_signals.reshape(B, self.n_samples, self.horizon, self.n_assets)[torch.arange(B), best_idx, 0]
        return {"weights": best_weights, "signals": best_signals, "expected_cost": costs.min(dim=1).values}
class CEMPlanner:
    """Cross-Entropy Method planner over portfolio weight trajectories.

    Iteratively samples candidate trajectories, keeps the ``n_elites``
    cheapest, and refits the Gaussian sampling distribution to the elites.
    """

    def __init__(self, model, n_assets: int, horizon=20, n_candidates=200, n_elites=20, n_iterations=5, action_std=0.3):
        # model must expose predict_next_state(state, weights, signals) -> tensor
        # whose dims 1 and 2 can be mean-reduced for the default cost.
        self.model = model
        self.n_assets = n_assets
        self.horizon = horizon
        self.n_candidates = n_candidates
        self.n_elites = n_elites
        self.n_iterations = n_iterations
        self.action_std = action_std  # initial per-asset sampling std

    def plan(self, state_series, cost_fn=None):
        """Plan first-step weights/signals via CEM.

        Args:
            state_series: batched state tensor (B, T, D) — TODO confirm layout.
            cost_fn: optional callable mapping predicted state -> cost of shape
                (B * n_candidates,); defaults to negative mean of the prediction.

        Returns:
            dict with "weights" (B, n_assets), "signals" (B, n_assets) for t=0,
            and "expected_cost" (B,), the best candidate cost.
        """
        device = state_series.device
        B = state_series.size(0)
        mean_weights = torch.ones(B, self.horizon, self.n_assets, device=device) / self.n_assets
        std_weights = torch.ones_like(mean_weights) * self.action_std
        # PERF FIX: the replicated state depends only on state_series and
        # n_candidates, yet the original rebuilt it inside the innermost time
        # loop on every CEM iteration. Build it exactly once.
        state_rep = (state_series.unsqueeze(1)
                     .repeat(1, self.n_candidates, 1, 1)
                     .reshape(-1, state_series.size(1), state_series.size(2)))
        for it in range(self.n_iterations):
            noise = torch.randn(B, self.n_candidates, self.horizon, self.n_assets, device=device)
            candidates = mean_weights.unsqueeze(1) + std_weights.unsqueeze(1) * noise
            candidates = F.softmax(candidates, dim=-1)  # project samples onto the simplex
            # Discrete signals are drawn uniformly from {0,1,2}; CEM refits only
            # the continuous weight distribution.
            signals = torch.randint(0, 3, (B, self.n_candidates, self.horizon, self.n_assets), device=device)
            costs = torch.zeros(B, self.n_candidates, device=device)
            for t in range(self.horizon):
                w_t = candidates[:, :, t].reshape(-1, self.n_assets)
                s_t = signals[:, :, t].reshape(-1, self.n_assets)
                # NOTE(review): the rollout never feeds z_next back into the
                # state, so each step is predicted from the original series —
                # confirm against the model's contract.
                z_next = self.model.predict_next_state(state_rep, w_t, s_t)
                if cost_fn is not None:
                    c = cost_fn(z_next)
                else:
                    c = -z_next.mean(dim=(1, 2))
                costs += c.reshape(B, self.n_candidates)
            # Refit the Gaussian to the n_elites lowest-cost candidates.
            _, elite_idx = torch.topk(costs, self.n_elites, dim=1, largest=False)
            elite_weights = torch.gather(candidates, 1,
                                         elite_idx.unsqueeze(-1).unsqueeze(-1).expand(-1, -1, self.horizon, self.n_assets))
            mean_weights = elite_weights.mean(dim=1)
            std_weights = elite_weights.std(dim=1) + 1e-4  # floor keeps std positive
        best_idx = costs.argmin(dim=1)
        best_weights = candidates[torch.arange(B), best_idx, 0]
        best_signals = signals[torch.arange(B), best_idx, 0]
        return {"weights": best_weights, "signals": best_signals, "expected_cost": costs.min(dim=1).values}
def sharpe_cost(z_pred, target_return=0.0):
    """Cost = negative mean excess return over ``target_return``.

    BUG FIX: the original accepted ``target_return`` but ignored it; it now
    shifts the return before negation. The default of 0.0 reproduces the old
    behavior exactly, so existing callers are unaffected.

    NOTE(review): despite the name, this does not normalize by volatility and
    is therefore not a true Sharpe ratio — consider renaming or adding a std
    denominator in a follow-up.

    Args:
        z_pred: prediction tensor of rank >= 3; dims 1 and 2 are averaged to
            produce a per-sample return.
        target_return: scalar return threshold subtracted before negation.

    Returns:
        1-D tensor of per-sample costs (lower is better).
    """
    returns = z_pred.mean(dim=(1, 2))
    return -(returns - target_return)