# File size: 5,492 Bytes
# 9ed614d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
# (removed: stray line-number residue "1".."99" left over from file extraction)
import torch
import torch.nn.functional as F
import numpy as np
from typing import Optional, Callable

class MPPIPlanner:
    """Model Predictive Path Integral (MPPI) planner over portfolio weights
    and discrete per-asset signals.

    Each iteration samples weight trajectories around a running mean, rolls
    them through the dynamics model, and refits the sampling distribution
    using the exponentially transformed (soft-min) costs.
    """

    def __init__(self, model, n_assets: int, horizon=20, n_samples=200, n_iterations=20,
                 temperature=0.005, action_std=0.2, signal_prior=0.33):
        # model must expose predict_next_state(state, weights, signals).
        self.model = model
        self.n_assets = n_assets
        self.horizon = horizon
        self.n_samples = n_samples
        self.n_iterations = n_iterations
        self.temperature = temperature  # softness of the MPPI cost weighting
        self.action_std = action_std    # initial exploration std for weights
        self.signal_prior = signal_prior  # NOTE(review): currently unused by plan()

    def plan(self, state_series, cost_fn=None, initial_weights=None):
        """Optimize a weight/signal trajectory; return the first-step actions.

        Args:
            state_series: state tensor consumed by the model (leading batch of 1).
            cost_fn: optional callable mapping the model prediction to a
                per-sample cost (lower is better); defaults to the negative
                mean of the prediction over its last two dims.
            initial_weights: optional warm start, broadcastable to
                (horizon, n_assets).

        Returns:
            dict with "weights" (B, n_assets), "signals" (B, n_assets) and
            "expected_cost" (B,) for the best sampled trajectory.
        """
        device = state_series.device
        B = 1  # single planning problem; kept explicit for shape clarity
        if initial_weights is None:
            mean_weights = torch.ones(B, self.horizon, self.n_assets, device=device) / self.n_assets
        else:
            mean_weights = initial_weights.unsqueeze(0).expand(B, self.horizon, self.n_assets)
        std_weights = torch.ones_like(mean_weights) * self.action_std
        # The rollout start state is iteration-invariant: replicate it once
        # instead of once per iteration.
        current_state = state_series.repeat(self.n_samples, 1, 1)
        for it in range(self.n_iterations):
            # Sample (B, S, H, A) so mean/std broadcast per problem, then
            # flatten to (B*S, H, A) for the rollout. (The original added a
            # (B,1,H,A) mean to (B*S,H,A) noise, which broadcasts to
            # (B, B*S, H, A) and made `[:, t]` index samples, not time.)
            noise = torch.randn(B, self.n_samples, self.horizon, self.n_assets, device=device)
            sampled_weights = mean_weights.unsqueeze(1) + std_weights.unsqueeze(1) * noise
            sampled_weights = F.softmax(sampled_weights, dim=-1)  # project onto the simplex
            flat_weights = sampled_weights.reshape(B * self.n_samples, self.horizon, self.n_assets)
            sampled_signals = torch.randint(0, 3, (B * self.n_samples, self.horizon, self.n_assets), device=device)
            costs = torch.zeros(B * self.n_samples, device=device)
            for t in range(self.horizon):
                w_t = flat_weights[:, t]
                s_t = sampled_signals[:, t]
                # NOTE(review): the state is not advanced with z_next between
                # steps — confirm the model expects the raw series every step.
                z_next = self.model.predict_next_state(current_state, w_t, s_t)
                if cost_fn is not None:
                    c_t = cost_fn(z_next)
                else:
                    c_t = -z_next.mean(dim=(1, 2))
                costs += c_t
            costs = costs.reshape(B, self.n_samples)
            # MPPI soft-min update: subtract the per-problem minimum before
            # exponentiating for numerical stability.
            beta = costs.min(dim=1, keepdim=True).values
            weights_mppi = torch.exp(-(costs - beta) / self.temperature)
            weights_mppi = weights_mppi / weights_mppi.sum(dim=1, keepdim=True)
            mean_weights = (weights_mppi[:, :, None, None] * sampled_weights).sum(dim=1)
            # Small floor on the variance keeps exploration from collapsing.
            std_weights = torch.sqrt((weights_mppi[:, :, None, None] * (sampled_weights - mean_weights.unsqueeze(1)) ** 2).sum(dim=1) + 1e-4)
        best_idx = costs.argmin(dim=1)
        batch_idx = torch.arange(B, device=device)
        best_weights = sampled_weights[batch_idx, best_idx, 0]
        best_signals = sampled_signals.reshape(B, self.n_samples, self.horizon, self.n_assets)[batch_idx, best_idx, 0]
        return {"weights": best_weights, "signals": best_signals, "expected_cost": costs.min(dim=1).values}

class CEMPlanner:
    """Cross-Entropy Method planner over portfolio weights and discrete
    per-asset signals.

    Each iteration samples candidate weight trajectories from a Gaussian,
    scores them by rolling through the dynamics model, and refits the
    Gaussian to the elite (lowest-cost) candidates.
    """

    def __init__(self, model, n_assets: int, horizon=20, n_candidates=200, n_elites=20, n_iterations=5, action_std=0.3):
        # model must expose predict_next_state(state, weights, signals).
        self.model = model
        self.n_assets = n_assets
        self.horizon = horizon
        self.n_candidates = n_candidates
        self.n_elites = n_elites
        self.n_iterations = n_iterations
        self.action_std = action_std  # initial exploration std for weights

    def plan(self, state_series, cost_fn=None):
        """Optimize a weight/signal trajectory; return the first-step actions.

        Args:
            state_series: batched state tensor (B, T, D) consumed by the model.
            cost_fn: optional callable mapping the model prediction to a
                per-sample cost (lower is better); defaults to the negative
                mean of the prediction over its last two dims.

        Returns:
            dict with "weights" (B, n_assets), "signals" (B, n_assets) and
            "expected_cost" (B,) for the best candidate per batch element.
        """
        device = state_series.device
        B = state_series.size(0)
        mean_weights = torch.ones(B, self.horizon, self.n_assets, device=device) / self.n_assets
        std_weights = torch.ones_like(mean_weights) * self.action_std
        # The replicated start state does not depend on the iteration or time
        # step, so build it once (was rebuilt inside the t-loop every step).
        state_rep = state_series.unsqueeze(1).repeat(1, self.n_candidates, 1, 1)
        state_rep = state_rep.reshape(-1, state_series.size(1), state_series.size(2))
        for it in range(self.n_iterations):
            noise = torch.randn(B, self.n_candidates, self.horizon, self.n_assets, device=device)
            candidates = mean_weights.unsqueeze(1) + std_weights.unsqueeze(1) * noise
            candidates = F.softmax(candidates, dim=-1)  # project onto the simplex
            signals = torch.randint(0, 3, (B, self.n_candidates, self.horizon, self.n_assets), device=device)
            costs = torch.zeros(B, self.n_candidates, device=device)
            for t in range(self.horizon):
                w_t = candidates[:, :, t].reshape(-1, self.n_assets)
                s_t = signals[:, :, t].reshape(-1, self.n_assets)
                # NOTE(review): the state is not advanced with z_next between
                # steps — confirm the model expects the raw series every step.
                z_next = self.model.predict_next_state(state_rep, w_t, s_t)
                if cost_fn is not None:
                    c = cost_fn(z_next)
                else:
                    c = -z_next.mean(dim=(1, 2))
                costs += c.reshape(B, self.n_candidates)
            # Refit the sampling distribution to the elite set.
            _, elite_idx = torch.topk(costs, self.n_elites, dim=1, largest=False)
            elite_weights = torch.gather(candidates, 1,
                elite_idx.unsqueeze(-1).unsqueeze(-1).expand(-1, -1, self.horizon, self.n_assets))
            mean_weights = elite_weights.mean(dim=1)
            std_weights = elite_weights.std(dim=1) + 1e-4  # floor keeps exploration alive
        best_idx = costs.argmin(dim=1)
        batch_idx = torch.arange(B, device=device)
        best_weights = candidates[batch_idx, best_idx, 0]
        best_signals = signals[batch_idx, best_idx, 0]
        return {"weights": best_weights, "signals": best_signals, "expected_cost": costs.min(dim=1).values}

def sharpe_cost(z_pred, target_return=0.0):
    """Negative Sharpe-style cost for planner rollouts (lower is better).

    The mean of `z_pred` over its last two dims is treated as the expected
    return and their standard deviation as the risk. The original version
    ignored `target_return` and performed no risk adjustment despite the name.

    Args:
        z_pred: prediction tensor of shape (batch, *, *).
        target_return: benchmark return subtracted before risk-adjusting.

    Returns:
        (batch,) tensor: negated risk-adjusted excess return.
    """
    mean_ret = z_pred.mean(dim=(1, 2))
    vol = z_pred.std(dim=(1, 2)) + 1e-8  # guard against zero volatility
    return -((mean_ret - target_return) / vol)