| import torch |
| import torch.nn.functional as F |
| import numpy as np |
| from typing import Optional, Callable |
|
|
class MPPIPlanner:
    """Model Predictive Path Integral (MPPI) planner over portfolio weights.

    Samples perturbed weight trajectories around a running mean, rolls them
    through ``model.predict_next_state``, scores them with ``cost_fn`` (or a
    default negative-mean cost), and refits the sampling distribution with
    exponentially weighted averaging. Returns only the first-step action
    (receding-horizon / MPC style).
    """

    def __init__(self, model, n_assets: int, horizon=20, n_samples=200, n_iterations=20,
                 temperature=0.005, action_std=0.2, signal_prior=0.33):
        # model must expose predict_next_state(state_batch, weights_t, signals_t).
        self.model = model
        self.n_assets = n_assets
        self.horizon = horizon
        self.n_samples = n_samples
        self.n_iterations = n_iterations
        # Lower temperature -> sharper weighting toward the lowest-cost samples.
        self.temperature = temperature
        self.action_std = action_std
        # NOTE(review): stored but not used anywhere in this class — presumably
        # intended as a prior for signal sampling; confirm before removing.
        self.signal_prior = signal_prior

    def plan(self, state_series, cost_fn=None, initial_weights=None):
        """Plan portfolio weights/signals for the first step of the horizon.

        Args:
            state_series: tensor of shape (1, seq_len, state_dim) — assumed
                single planning problem (B is fixed to 1 below); TODO confirm.
            cost_fn: optional callable mapping the model's prediction to a
                per-sample cost vector; defaults to negative mean of the output.
            initial_weights: optional warm-start weights broadcastable to
                (horizon, n_assets).

        Returns:
            dict with "weights" (1, n_assets), "signals" (1, n_assets) and
            "expected_cost" (1,) — the first-timestep action of the best sample.
        """
        device = state_series.device
        B = 1  # single planning instance; the sample axis carries the rollouts
        if initial_weights is None:
            mean_weights = torch.ones(B, self.horizon, self.n_assets, device=device) / self.n_assets
        else:
            mean_weights = initial_weights.unsqueeze(0).expand(B, self.horizon, self.n_assets)
        std_weights = torch.ones_like(mean_weights) * self.action_std
        # Invariant across iterations — build the replicated state batch once.
        current_state = state_series.repeat(self.n_samples, 1, 1)
        for _ in range(self.n_iterations):
            noise = torch.randn(B * self.n_samples, self.horizon, self.n_assets, device=device)
            # BUG FIX: the original applied unsqueeze(1) to both mean and std,
            # which broadcast the sum to a 4-D (1, B*S, H, A) tensor, so the
            # later sampled_weights[:, t] selected *samples*, not time steps.
            # With (B, H, A) mean/std the broadcast yields (B*S, H, A) directly.
            sampled_weights = mean_weights + std_weights * noise
            # Project perturbed weights back onto the simplex per time step.
            sampled_weights = F.softmax(sampled_weights, dim=-1)
            # Discrete signals in {0, 1, 2} are resampled uniformly each
            # iteration (no distribution refinement on the signal dimension).
            sampled_signals = torch.randint(0, 3, (B * self.n_samples, self.horizon, self.n_assets), device=device)
            costs = torch.zeros(B * self.n_samples, device=device)
            for t in range(self.horizon):
                w_t = sampled_weights[:, t]
                s_t = sampled_signals[:, t]
                # NOTE(review): current_state is never advanced with z_next —
                # presumably predict_next_state conditions on the full series;
                # confirm against the model's contract.
                z_next = self.model.predict_next_state(current_state, w_t, s_t)
                if cost_fn is not None:
                    c_t = cost_fn(z_next)
                else:
                    c_t = -z_next.mean(dim=(1, 2))
                costs += c_t
            costs = costs.reshape(B, self.n_samples)
            # Subtract the per-batch minimum for numerical stability of exp().
            beta = costs.min(dim=1, keepdim=True).values
            weights_mppi = torch.exp(-(costs - beta) / self.temperature)
            weights_mppi = weights_mppi / weights_mppi.sum(dim=1, keepdim=True)
            flat_weights = sampled_weights.reshape(B, self.n_samples, self.horizon, self.n_assets)
            # Refit mean/std as the importance-weighted moments of the samples;
            # the 1e-4 floor keeps std strictly positive for the next draw.
            mean_weights = (weights_mppi[:, :, None, None] * flat_weights).sum(dim=1)
            std_weights = torch.sqrt((weights_mppi[:, :, None, None] * (flat_weights - mean_weights.unsqueeze(1)) ** 2).sum(dim=1) + 1e-4)
        best_idx = costs.argmin(dim=1)
        # Index 0 on the horizon axis: execute only the first planned step.
        best_weights = sampled_weights.reshape(B, self.n_samples, self.horizon, self.n_assets)[torch.arange(B), best_idx, 0]
        best_signals = sampled_signals.reshape(B, self.n_samples, self.horizon, self.n_assets)[torch.arange(B), best_idx, 0]
        return {"weights": best_weights, "signals": best_signals, "expected_cost": costs.min(dim=1).values}
|
|
class CEMPlanner:
    """Cross-Entropy Method planner over portfolio weight trajectories.

    Each iteration samples candidate weight trajectories from a Gaussian,
    softmax-projects them onto the simplex, scores them through
    ``model.predict_next_state``, and refits the Gaussian to the lowest-cost
    elites. Returns the first-step action of the best candidate.
    """

    def __init__(self, model, n_assets: int, horizon=20, n_candidates=200, n_elites=20, n_iterations=5, action_std=0.3):
        # model must expose predict_next_state(state_batch, weights_t, signals_t).
        self.model = model
        self.n_assets = n_assets
        self.horizon = horizon
        self.n_candidates = n_candidates
        self.n_elites = n_elites
        self.n_iterations = n_iterations
        self.action_std = action_std

    def plan(self, state_series, cost_fn=None):
        """Plan first-step weights/signals for a batch of state series.

        Args:
            state_series: tensor of shape (B, seq_len, state_dim).
            cost_fn: optional callable mapping the model's prediction to a
                per-rollout cost vector; defaults to negative mean of the output.

        Returns:
            dict with "weights" (B, n_assets), "signals" (B, n_assets) and
            "expected_cost" (B,).
        """
        device = state_series.device
        B = state_series.size(0)
        mean_weights = torch.ones(B, self.horizon, self.n_assets, device=device) / self.n_assets
        std_weights = torch.ones_like(mean_weights) * self.action_std
        # PERF FIX: this replicated state batch was rebuilt inside the
        # per-timestep loop although it depends only on state_series and
        # n_candidates — hoist it out of both loops (behavior unchanged).
        state_rep = state_series.unsqueeze(1).repeat(1, self.n_candidates, 1, 1).reshape(
            -1, state_series.size(1), state_series.size(2))
        for _ in range(self.n_iterations):
            noise = torch.randn(B, self.n_candidates, self.horizon, self.n_assets, device=device)
            candidates = mean_weights.unsqueeze(1) + std_weights.unsqueeze(1) * noise
            # Project each candidate trajectory onto the simplex per time step.
            candidates = F.softmax(candidates, dim=-1)
            # Discrete signals in {0, 1, 2}, resampled uniformly every iteration.
            signals = torch.randint(0, 3, (B, self.n_candidates, self.horizon, self.n_assets), device=device)
            costs = torch.zeros(B, self.n_candidates, device=device)
            for t in range(self.horizon):
                w_t = candidates[:, :, t].reshape(-1, self.n_assets)
                s_t = signals[:, :, t].reshape(-1, self.n_assets)
                # NOTE(review): the state is never advanced with z_next —
                # presumably the model conditions on the full series; confirm.
                z_next = self.model.predict_next_state(state_rep, w_t, s_t)
                if cost_fn is not None:
                    c = cost_fn(z_next)
                else:
                    c = -z_next.mean(dim=(1, 2))
                costs += c.reshape(B, self.n_candidates)
            # Elites = the n_elites lowest-cost candidates per batch element.
            _, elite_idx = torch.topk(costs, self.n_elites, dim=1, largest=False)
            elite_weights = torch.gather(candidates, 1,
                                         elite_idx.unsqueeze(-1).unsqueeze(-1).expand(-1, -1, self.horizon, self.n_assets))
            # Refit the sampling Gaussian; the 1e-4 floor keeps std positive.
            mean_weights = elite_weights.mean(dim=1)
            std_weights = elite_weights.std(dim=1) + 1e-4
        best_idx = costs.argmin(dim=1)
        # Index 0 on the horizon axis: execute only the first planned step.
        best_weights = candidates[torch.arange(B), best_idx, 0]
        best_signals = signals[torch.arange(B), best_idx, 0]
        return {"weights": best_weights, "signals": best_signals, "expected_cost": costs.min(dim=1).values}
|
|
def sharpe_cost(z_pred, target_return=0.0):
    """Negative Sharpe-style cost: lower is better for the planners above.

    BUG FIX: the original ignored ``target_return`` and applied no risk
    adjustment despite the name — it returned plain negative mean. This
    version divides the excess return by the dispersion of the prediction.

    Args:
        z_pred: tensor of shape (batch, ...) — mean over all non-batch dims
            is treated as the predicted return, their std as a volatility
            proxy (TODO confirm against the model's output semantics).
        target_return: benchmark return subtracted before risk scaling.

    Returns:
        1-D tensor of per-sample costs, shape (batch,).
    """
    returns = z_pred.mean(dim=(1, 2))
    # flatten(1) pools every non-batch dim; epsilon guards constant outputs.
    volatility = z_pred.flatten(1).std(dim=1) + 1e-8
    return -(returns - target_return) / volatility
|
|