import torch
import torch.nn.functional as F
import numpy as np
from typing import Optional, Callable


def _default_cost(z_next: torch.Tensor) -> torch.Tensor:
    """Default per-step cost: negative mean of the model output over dims 1 and 2."""
    return -z_next.mean(dim=(1, 2))


def _rollout_costs(
    model,
    flat_state: torch.Tensor,
    weights: torch.Tensor,
    signals: torch.Tensor,
    cost_fn: Optional[Callable[[torch.Tensor], torch.Tensor]],
) -> torch.Tensor:
    """Sum per-step costs over the horizon for K flattened candidates.

    Args:
        model: object exposing ``predict_next_state(state, weights, signals)``.
        flat_state: conditioning state with one row per candidate (leading dim K).
        weights: (K, horizon, n_assets) candidate portfolio weights.
        signals: (K, horizon, n_assets) integer signals in {0, 1, 2}.
        cost_fn: optional callable mapping the model output to a (K,) cost;
            defaults to ``_default_cost``.

    Returns:
        (K,) tensor of accumulated costs.
    """
    step_cost = cost_fn if cost_fn is not None else _default_cost
    costs = torch.zeros(weights.size(0), device=weights.device)
    for t in range(weights.size(1)):
        # NOTE(review): every step conditions on the same initial state; the
        # prediction is not fed back into the next step. Preserved from the
        # original implementation — confirm this is intended.
        z_next = model.predict_next_state(flat_state, weights[:, t], signals[:, t])
        costs += step_cost(z_next)
    return costs


class MPPIPlanner:
    """Model Predictive Path Integral planner over portfolio weights and signals.

    Portfolio weights are sampled as Gaussian logits and mapped to the simplex
    with a softmax; signals are drawn uniformly from {0, 1, 2}. Each iteration
    refits the Gaussian over logits with exponentiated-cost (path-integral)
    weights.
    """

    def __init__(self, model, n_assets: int, horizon=20, n_samples=200,
                 n_iterations=20, temperature=0.005, action_std=0.2,
                 signal_prior=0.33):
        self.model = model
        self.n_assets = n_assets
        self.horizon = horizon
        self.n_samples = n_samples
        self.n_iterations = n_iterations
        self.temperature = temperature
        self.action_std = action_std
        # Stored for API compatibility; plan() samples signals uniformly and
        # does not read this value.
        self.signal_prior = signal_prior

    @torch.no_grad()  # sampling-based planner: no gradients needed, and the
    # iterative reweighting would otherwise chain the graph across iterations.
    def plan(self, state_series: torch.Tensor,
             cost_fn: Optional[Callable[[torch.Tensor], torch.Tensor]] = None,
             initial_weights: Optional[torch.Tensor] = None):
        """Plan portfolio weights and signals for the first horizon step.

        Args:
            state_series: conditioning state; dim 0 is the batch dimension
                (presumably (B, T, D) — TODO confirm against the model).
            cost_fn: optional callable mapping the model output for all
                samples to a per-sample cost; lower is better.
            initial_weights: optional initial mean of the pre-softmax logits,
                broadcast to (B, horizon, n_assets) (e.g. shape (n_assets,)).

        Returns:
            dict with ``"weights"`` (B, n_assets) softmax weights of the
            lowest-cost sample at step 0, ``"signals"`` (B, n_assets) its
            signals at step 0, and ``"expected_cost"`` (B,) its total cost.
            All from the final iteration.

        Raises:
            ValueError: if ``n_iterations`` < 1 (no samples would be drawn).
        """
        if self.n_iterations < 1:
            raise ValueError("n_iterations must be >= 1")
        device = state_series.device
        B = state_series.size(0)
        S, H, N = self.n_samples, self.horizon, self.n_assets

        if initial_weights is None:
            mean_logits = torch.ones(B, H, N, device=device) / N
        else:
            mean_logits = initial_weights.to(device).unsqueeze(0).expand(B, H, N)
        std_logits = torch.ones_like(mean_logits) * self.action_std

        # Batch-major tiling so flattened rows line up with reshape(B, S).
        flat_state = state_series.repeat_interleave(S, dim=0)

        for _ in range(self.n_iterations):
            noise = torch.randn(B * S, H, N, device=device).view(B, S, H, N)
            logits = mean_logits.unsqueeze(1) + std_logits.unsqueeze(1) * noise
            weights = F.softmax(logits, dim=-1)  # (B, S, H, N)
            signals = torch.randint(0, 3, (B * S, H, N), device=device)

            costs = _rollout_costs(
                self.model, flat_state, weights.reshape(B * S, H, N), signals, cost_fn
            ).view(B, S)

            # Subtract the per-batch minimum before exponentiating for stability.
            beta = costs.min(dim=1, keepdim=True).values
            path_w = torch.exp(-(costs - beta) / self.temperature)
            path_w = (path_w / path_w.sum(dim=1, keepdim=True))[:, :, None, None]

            # Refit in logit space: the sampling distribution lives there, so
            # refitting on softmax outputs would mix two parameterizations.
            mean_logits = (path_w * logits).sum(dim=1)
            std_logits = torch.sqrt(
                (path_w * (logits - mean_logits.unsqueeze(1)) ** 2).sum(dim=1) + 1e-4
            )

        best_idx = costs.argmin(dim=1)
        rows = torch.arange(B, device=device)
        best_weights = weights[rows, best_idx, 0]
        best_signals = signals.view(B, S, H, N)[rows, best_idx, 0]
        return {"weights": best_weights, "signals": best_signals,
                "expected_cost": costs.min(dim=1).values}


class CEMPlanner:
    """Cross-Entropy Method planner over portfolio weights and signals.

    Weights are sampled as Gaussian logits passed through a softmax; signals
    are drawn uniformly from {0, 1, 2}. Each iteration refits the Gaussian
    over logits to the lowest-cost elite candidates.
    """

    def __init__(self, model, n_assets: int, horizon=20, n_candidates=200,
                 n_elites=20, n_iterations=5, action_std=0.3):
        self.model = model
        self.n_assets = n_assets
        self.horizon = horizon
        self.n_candidates = n_candidates
        self.n_elites = n_elites
        self.n_iterations = n_iterations
        self.action_std = action_std

    @torch.no_grad()  # gradient-free planner; avoid building autograd graphs.
    def plan(self, state_series: torch.Tensor,
             cost_fn: Optional[Callable[[torch.Tensor], torch.Tensor]] = None):
        """Plan portfolio weights and signals for the first horizon step.

        Args:
            state_series: conditioning state; dim 0 is the batch dimension
                (presumably (B, T, D) — TODO confirm against the model).
            cost_fn: optional callable mapping the model output for all
                candidates to a per-candidate cost; lower is better.

        Returns:
            dict with ``"weights"`` (B, n_assets), ``"signals"`` (B, n_assets)
            and ``"expected_cost"`` (B,) for the lowest-cost candidate of the
            final iteration, at horizon step 0.

        Raises:
            ValueError: if ``n_iterations`` < 1 or ``n_elites`` < 1.
        """
        if self.n_iterations < 1:
            raise ValueError("n_iterations must be >= 1")
        if self.n_elites < 1:
            raise ValueError("n_elites must be >= 1")
        device = state_series.device
        B = state_series.size(0)
        C, H, N = self.n_candidates, self.horizon, self.n_assets
        # topk would raise if asked for more elites than candidates.
        n_elites = min(self.n_elites, C)

        mean_logits = torch.ones(B, H, N, device=device) / N
        std_logits = torch.ones_like(mean_logits) * self.action_std
        # Loop-invariant: tile the state once, batch-major to match view(B, C).
        flat_state = state_series.repeat_interleave(C, dim=0)

        for _ in range(self.n_iterations):
            noise = torch.randn(B, C, H, N, device=device)
            logits = mean_logits.unsqueeze(1) + std_logits.unsqueeze(1) * noise
            candidates = F.softmax(logits, dim=-1)
            signals = torch.randint(0, 3, (B, C, H, N), device=device)

            costs = _rollout_costs(
                self.model, flat_state,
                candidates.reshape(B * C, H, N), signals.reshape(B * C, H, N),
                cost_fn,
            ).view(B, C)

            _, elite_idx = torch.topk(costs, n_elites, dim=1, largest=False)
            elite_logits = torch.gather(
                logits, 1, elite_idx[:, :, None, None].expand(-1, -1, H, N)
            )
            mean_logits = elite_logits.mean(dim=1)
            if n_elites > 1:
                std_logits = elite_logits.std(dim=1) + 1e-4
            else:
                # Unbiased std of a single sample is NaN; fall back to the floor.
                std_logits = torch.full_like(mean_logits, 1e-4)

        best_idx = costs.argmin(dim=1)
        rows = torch.arange(B, device=device)
        best_weights = candidates[rows, best_idx, 0]
        best_signals = signals[rows, best_idx, 0]
        return {"weights": best_weights, "signals": best_signals,
                "expected_cost": costs.min(dim=1).values}


def sharpe_cost(z_pred: torch.Tensor, target_return: float = 0.0) -> torch.Tensor:
    """Negative mean of ``z_pred`` over dims 1 and 2, in excess of ``target_return``.

    NOTE: despite the name this is not a Sharpe ratio (there is no volatility
    normalisation); it is a mean-return proxy. ``target_return`` is a constant
    shift, so it does not change how a planner ranks candidates.
    """
    returns = z_pred.mean(dim=(1, 2))
    return -(returns - target_return)