ashesh8500 committed on
Commit
9ed614d
·
verified ·
1 Parent(s): d5789cd

Upload planner.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. planner.py +98 -0
planner.py ADDED
@@ -0,0 +1,98 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn.functional as F
3
+ import numpy as np
4
+ from typing import Optional, Callable
5
+
6
class MPPIPlanner:
    """Sampling-based planner with MPPI-style exponentially weighted refits.

    Each iteration samples weight sequences around a running mean, scores
    every sequence by calling ``model.predict_next_state`` once per horizon
    step, and refits the mean/std of the sampling distribution using
    importance weights ``exp(-(cost - min_cost) / temperature)``.

    Args:
        model: Object exposing ``predict_next_state(state, weights, signals)``.
        n_assets: Size of the last dimension of sampled weights and signals.
        horizon: Number of planning steps in each sampled sequence.
        n_samples: Sequences sampled per iteration for each batch item.
        n_iterations: Number of sample/score/refit rounds.
        temperature: Divisor applied to shifted costs before exponentiation.
        action_std: Initial standard deviation of the sampling noise.
        signal_prior: Stored on the instance; ``plan`` does not read it.
    """

    def __init__(self, model, n_assets: int, horizon=20, n_samples=200, n_iterations=20,
                 temperature=0.005, action_std=0.2, signal_prior=0.33):
        self.model = model
        self.n_assets = n_assets
        self.horizon = horizon
        self.n_samples = n_samples
        self.n_iterations = n_iterations
        self.temperature = temperature
        self.action_std = action_std
        self.signal_prior = signal_prior

    def plan(self, state_series, cost_fn=None, initial_weights=None):
        """Search for a low-cost weight/signal sequence and return its first step.

        Args:
            state_series: Tensor handed to the model (tiled once per sample).
                A 2-D tensor is treated as a single unbatched item; otherwise
                dimension 0 is the batch dimension.
            cost_fn: Optional callable mapping the model's prediction to a
                per-rollout cost of ``batch * n_samples`` elements. When
                omitted, the cost is the negated mean of the prediction over
                dimensions 1 and 2.
            initial_weights: Optional tensor that, after a leading
                ``unsqueeze(0)``, expands to ``(batch, horizon, n_assets)``;
                used as the initial sampling mean. Defaults to ``1 / n_assets``.

        Returns:
            Dict with ``"weights"`` and ``"signals"`` (each
            ``(batch, n_assets)``: the first-step action of the cheapest
            sequence sampled in the final iteration) and ``"expected_cost"``
            (``(batch,)``: that sequence's accumulated cost).

        Raises:
            ValueError: If ``n_iterations`` is less than 1.
        """
        if self.n_iterations < 1:
            raise ValueError("n_iterations must be at least 1")
        if state_series.dim() == 2:
            # Same result the original got from repeat() on an unbatched state.
            state_series = state_series.unsqueeze(0)

        device = state_series.device
        batch = state_series.size(0)
        n_rollouts = batch * self.n_samples
        sample_shape = (batch, self.n_samples, self.horizon, self.n_assets)

        if initial_weights is None:
            mean_weights = torch.ones(batch, self.horizon, self.n_assets, device=device) / self.n_assets
        else:
            mean_weights = initial_weights.unsqueeze(0).expand(batch, self.horizon, self.n_assets)
        std_weights = torch.ones_like(mean_weights) * self.action_std

        # Batch-major tiling so row i * n_samples + j is sample j of batch item
        # i, matching the reshape(batch, n_samples) applied to the costs below.
        rollout_state = state_series.repeat_interleave(self.n_samples, dim=0)

        for _ in range(self.n_iterations):
            noise = torch.randn(sample_shape, device=device)
            # NOTE(review): after the first refit, mean_weights is an average
            # of softmax outputs yet is fed back in as a pre-softmax value;
            # kept as in the original -- confirm this is intended.
            sampled_weights = F.softmax(
                mean_weights.unsqueeze(1) + std_weights.unsqueeze(1) * noise, dim=-1
            )
            sampled_signals = torch.randint(0, 3, sample_shape, device=device)

            # Flatten (batch, sample) into one rollout axis so that [:, t]
            # selects a time step. Indexing the 4-D tensor directly would pick
            # sample t instead and hand the model mis-shaped weights.
            flat_weights = sampled_weights.reshape(n_rollouts, self.horizon, self.n_assets)
            flat_signals = sampled_signals.reshape(n_rollouts, self.horizon, self.n_assets)

            costs = torch.zeros(n_rollouts, device=device)
            for t in range(self.horizon):
                # NOTE(review): every step is predicted from the same input
                # state; the prediction is not fed back as the next state.
                # Kept as in the original -- confirm this is intended.
                z_next = self.model.predict_next_state(
                    rollout_state, flat_weights[:, t], flat_signals[:, t]
                )
                if cost_fn is not None:
                    step_cost = cost_fn(z_next)
                else:
                    step_cost = -z_next.mean(dim=(1, 2))
                costs += step_cost

            costs = costs.reshape(batch, self.n_samples)
            # Subtracting the per-item minimum keeps the exponent <= 0, so the
            # best sample always has weight exp(0) = 1 before normalization.
            beta = costs.min(dim=1, keepdim=True).values
            importance = torch.exp(-(costs - beta) / self.temperature)
            importance = importance / importance.sum(dim=1, keepdim=True)
            importance = importance[:, :, None, None]

            mean_weights = (importance * sampled_weights).sum(dim=1)
            deviation_sq = (sampled_weights - mean_weights.unsqueeze(1)) ** 2
            # The 1e-4 floor keeps the sampling std strictly positive.
            std_weights = torch.sqrt((importance * deviation_sq).sum(dim=1) + 1e-4)

        best_idx = costs.argmin(dim=1)
        batch_idx = torch.arange(batch, device=best_idx.device)
        return {
            "weights": sampled_weights[batch_idx, best_idx, 0],
            "signals": sampled_signals[batch_idx, best_idx, 0],
            "expected_cost": costs.min(dim=1).values,
        }
54
+
55
class CEMPlanner:
    """Cross-entropy-method planner over per-step weight sequences.

    Each iteration samples candidate sequences around a running mean, scores
    them by calling ``model.predict_next_state`` once per horizon step, keeps
    the ``n_elites`` lowest-cost candidates, and refits the sampling mean/std
    to those elites.

    Args:
        model: Object exposing ``predict_next_state(state, weights, signals)``.
        n_assets: Size of the last dimension of sampled weights and signals.
        horizon: Number of planning steps in each candidate sequence.
        n_candidates: Candidates sampled per iteration for each batch item.
        n_elites: Number of lowest-cost candidates used for the refit.
        n_iterations: Number of sample/score/refit rounds.
        action_std: Initial standard deviation of the sampling noise.
    """

    def __init__(self, model, n_assets: int, horizon=20, n_candidates=200, n_elites=20,
                 n_iterations=5, action_std=0.3):
        self.model = model
        self.n_assets = n_assets
        self.horizon = horizon
        self.n_candidates = n_candidates
        self.n_elites = n_elites
        self.n_iterations = n_iterations
        self.action_std = action_std

    def plan(self, state_series, cost_fn=None):
        """Search for a low-cost weight/signal sequence and return its first step.

        Args:
            state_series: 3-D tensor whose dimension 0 is the batch; it is
                tiled once per candidate before being handed to the model.
            cost_fn: Optional callable mapping the model's prediction to a
                per-rollout cost of ``batch * n_candidates`` elements. When
                omitted, the cost is the negated mean of the prediction over
                dimensions 1 and 2.

        Returns:
            Dict with ``"weights"`` and ``"signals"`` (each
            ``(batch, n_assets)``: the first-step action of the cheapest
            candidate from the final iteration) and ``"expected_cost"``
            (``(batch,)``: that candidate's accumulated cost).

        Raises:
            ValueError: If ``n_iterations`` is less than 1.
        """
        if self.n_iterations < 1:
            raise ValueError("n_iterations must be at least 1")

        device = state_series.device
        batch = state_series.size(0)
        sample_shape = (batch, self.n_candidates, self.horizon, self.n_assets)

        mean_weights = torch.ones(batch, self.horizon, self.n_assets, device=device) / self.n_assets
        std_weights = torch.ones_like(mean_weights) * self.action_std

        # The tiled state does not change between steps or iterations, so it
        # is built once instead of horizon * n_iterations times.
        state_rep = (
            state_series.unsqueeze(1)
            .repeat(1, self.n_candidates, 1, 1)
            .reshape(-1, state_series.size(1), state_series.size(2))
        )

        for _ in range(self.n_iterations):
            noise = torch.randn(sample_shape, device=device)
            # NOTE(review): after the first refit, mean_weights is an average
            # of softmax outputs yet is fed back in as a pre-softmax value;
            # kept as in the original -- confirm this is intended.
            candidates = F.softmax(
                mean_weights.unsqueeze(1) + std_weights.unsqueeze(1) * noise, dim=-1
            )
            signals = torch.randint(0, 3, sample_shape, device=device)

            costs = torch.zeros(batch, self.n_candidates, device=device)
            for t in range(self.horizon):
                w_t = candidates[:, :, t].reshape(-1, self.n_assets)
                s_t = signals[:, :, t].reshape(-1, self.n_assets)
                # NOTE(review): every step is predicted from the same input
                # state; the prediction is not fed back as the next state.
                # Kept as in the original -- confirm this is intended.
                z_next = self.model.predict_next_state(state_rep, w_t, s_t)
                if cost_fn is not None:
                    step_cost = cost_fn(z_next)
                else:
                    step_cost = -z_next.mean(dim=(1, 2))
                costs += step_cost.reshape(batch, self.n_candidates)

            _, elite_idx = torch.topk(costs, self.n_elites, dim=1, largest=False)
            gather_idx = elite_idx[:, :, None, None].expand(-1, -1, self.horizon, self.n_assets)
            elite_weights = torch.gather(candidates, 1, gather_idx)

            mean_weights = elite_weights.mean(dim=1)
            if elite_weights.size(1) > 1:
                std_weights = elite_weights.std(dim=1) + 1e-4
            else:
                # The unbiased std of a single elite is NaN and would poison
                # every later iteration; fall back to the floor value alone.
                std_weights = torch.full_like(mean_weights, 1e-4)

        best_idx = costs.argmin(dim=1)
        batch_idx = torch.arange(batch, device=best_idx.device)
        return {
            "weights": candidates[batch_idx, best_idx, 0],
            "signals": signals[batch_idx, best_idx, 0],
            "expected_cost": costs.min(dim=1).values,
        }
95
+
96
def sharpe_cost(z_pred, target_return=0.0):
    """Return the negated mean of ``z_pred`` over dimensions 1 and 2.

    Args:
        z_pred: Tensor with at least three dimensions; dimensions 1 and 2 are
            averaged away, leaving one value per entry of dimension 0.
        target_return: Accepted for interface compatibility; the current
            implementation does not read it.

    Returns:
        Tensor holding ``-mean(z_pred, dim=(1, 2))``. No volatility
        normalization is applied.
    """
    return z_pred.mean(dim=(1, 2)).neg()