import torch
import torch.nn as nn

from .nets.attention_model.attention_model import *


class Problem:
    """Lightweight container holding the problem name (e.g. "tsp" or "cvrp")."""

    def __init__(self, name):
        self.NAME = name


class Backbone(nn.Module):
    """Encoder-decoder backbone: embeds the instance, encodes it with graph
    attention, and decodes per-step logits over candidate nodes."""

    def __init__(
        self,
        embedding_dim=128,
        problem_name="tsp",
        n_encode_layers=3,
        tanh_clipping=10.0,
        n_heads=8,
        device="cpu",
    ):
        super(Backbone, self).__init__()
        self.device = device
        self.problem = Problem(problem_name)
        # Problem-specific input embedding (e.g. node coordinates, demands).
        self.embedding = AutoEmbedding(self.problem.NAME, {"embedding_dim": embedding_dim})

        self.encoder = GraphAttentionEncoder(
            n_heads=n_heads,
            embed_dim=embedding_dim,
            n_layers=n_encode_layers,
        )

        self.decoder = Decoder(
            embedding_dim, self.embedding.context_dim, n_heads, self.problem, tanh_clipping
        )

    def forward(self, obs):
        """Run the full pass: embed and encode the instance, then decode one step."""
        state = stateWrapper(obs, device=self.device, problem=self.problem.NAME)
        inputs = state.states["observations"]
        embedding = self.embedding(inputs)
        encoded_inputs, _ = self.encoder(embedding)

        cached_embeddings = self.decoder._precompute(encoded_inputs)
        logits, glimpse = self.decoder.advance(cached_embeddings, state)

        return logits, glimpse

    def encode(self, obs):
        """Encode the instance once and return the decoder's precomputed cache."""
        state = stateWrapper(obs, device=self.device, problem=self.problem.NAME)
        inputs = state.states["observations"]
        embedding = self.embedding(inputs)
        encoded_inputs, _ = self.encoder(embedding)
        cached_embeddings = self.decoder._precompute(encoded_inputs)
        return cached_embeddings

    def decode(self, obs, cached_embeddings):
        """Decode one step against a previously computed embedding cache."""
        state = stateWrapper(obs, device=self.device, problem=self.problem.NAME)
        logits, glimpse = self.decoder.advance(cached_embeddings, state)
        return logits, glimpse
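

# Design note: Backbone.forward is `encode` followed by a single `decode` step.
# The split exists so callers can build the encoder cache once per instance and
# replay only the lightweight decoder at every rollout step.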


class Actor(nn.Module):
    """Picks the action logits out of the backbone output `(logits, glimpse)`."""

    def __init__(self):
        super(Actor, self).__init__()

    def forward(self, x):
        logits = x[0]
        return logits


class Critic(nn.Module):
    """State-value head: an MLP applied to the decoder glimpse."""

    def __init__(self, *args, **kwargs):
        super(Critic, self).__init__()
        hidden_size = kwargs["hidden_size"]
        self.mlp = nn.Sequential(
            nn.Linear(hidden_size, hidden_size), nn.ReLU(), nn.Linear(hidden_size, 1)
        )

    def forward(self, x):
        # x is the backbone output (logits, glimpse); the value is read from the glimpse.
        out = self.mlp(x[1])
        return out


class Agent(nn.Module):
    """Actor-critic agent sharing a single Backbone between both heads."""

    def __init__(self, embedding_dim=128, device="cpu", name="tsp"):
        super().__init__()
        self.backbone = Backbone(embedding_dim=embedding_dim, device=device, problem_name=name)
        self.critic = Critic(hidden_size=embedding_dim)
        self.actor = Actor()

    def forward(self, x):
        x = self.backbone(x)
        logits = self.actor(x)
        # Greedy selection: argmax over dim 2, the candidate-node dimension.
        action = logits.max(2)[1]
        return action, logits

    def get_value(self, x):
        x = self.backbone(x)
        return self.critic(x)

    def get_action_and_value(self, x, action=None):
        x = self.backbone(x)
        logits = self.actor(x)
        probs = torch.distributions.Categorical(logits=logits)
        if action is None:
            action = probs.sample()
        return action, probs.log_prob(action), probs.entropy(), self.critic(x)
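
    # The *_cached variants below realize the encode-once/decode-many split:
    # `get_action_and_value_cached` builds the encoder cache when `state` is
    # None and returns it, so callers can feed it back in on later steps of
    # the same instance instead of re-running the graph encoder.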
    def get_value_cached(self, x, state):
        x = self.backbone.decode(x, state)
        return self.critic(x)

    def get_action_and_value_cached(self, x, action=None, state=None):
        # Build the encoder cache on the first call; reuse it afterwards.
        if state is None:
            state = self.backbone.encode(x)
        x = self.backbone.decode(x, state)
        logits = self.actor(x)
        probs = torch.distributions.Categorical(logits=logits)
        if action is None:
            action = probs.sample()
        return action, probs.log_prob(action), probs.entropy(), self.critic(x), state
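

# A minimal rollout sketch for the cached path. Assumptions: `env` follows a
# gym-style reset/step API returning observation dicts compatible with
# stateWrapper; the helper name and loop structure are illustrative only and
# not part of this module's API.
def example_cached_rollout(agent, env, n_steps):
    obs = env.reset()
    cache = None
    for _ in range(n_steps):
        # The first iteration builds the encoder cache; later ones reuse it.
        action, logprob, entropy, value, cache = agent.get_action_and_value_cached(
            obs, state=cache
        )
        obs, reward, done, info = env.step(action.cpu().numpy())
    return obs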


class stateWrapper:
    """Wraps a dict of numpy observation arrays into an object exposing the
    tensors and helper methods that the attention-model decoder expects."""

    def __init__(self, states, device, problem="tsp"):
        self.device = device
        self.states = {k: torch.tensor(v, device=self.device) for k, v in states.items()}
        if problem == "tsp":
            self.is_initial_action = self.states["is_initial_action"].to(torch.bool)
            self.first_a = self.states["first_node_idx"]
        elif problem == "cvrp":
            inputs = {
                "loc": self.states["observations"],
                "depot": self.states["depot"].squeeze(-1),
                "demand": self.states["demand"],
            }
            self.states["observations"] = inputs
            # Capacity is tracked relative to zero: used_capacity is the negated
            # current load, so the remaining capacity works out to the current load.
            self.VEHICLE_CAPACITY = 0
            self.used_capacity = -self.states["current_load"]

    def get_current_node(self):
        return self.states["last_node_idx"]

    def get_mask(self):
        # action_mask marks feasible actions with 1; the decoder expects the inverse.
        return (1 - self.states["action_mask"]).to(torch.bool)
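

# A minimal smoke-test sketch (run as a module so the relative import resolves).
# The observation keys mirror the TSP branch of stateWrapper above; the batch
# size, node count, and array shapes are assumptions about the expected layout,
# not a documented interface.
if __name__ == "__main__":
    import numpy as np

    batch, n_nodes = 2, 10
    obs = {
        "observations": np.random.rand(batch, n_nodes, 2).astype(np.float32),
        "action_mask": np.ones((batch, n_nodes), dtype=np.int64),
        "first_node_idx": np.zeros((batch, 1), dtype=np.int64),
        "last_node_idx": np.zeros((batch, 1), dtype=np.int64),
        "is_initial_action": np.ones((batch, 1), dtype=np.int64),
    }
    agent = Agent(embedding_dim=128, device="cpu", name="tsp")
    action, logprob, entropy, value, cache = agent.get_action_and_value_cached(obs)
    print(action.shape, value.shape)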