| |
|
|
| import torch |
| import torch.nn as nn |
|
|
| from .base import BaseModel |
| from .feature_extractor import FeatureExtractor |
|
|
|
|
class MapEncoder(BaseModel):
    """Embed a multi-layer categorical map and encode it into feature maps.

    Each map layer gets its own ``torch.nn.Embedding``; the per-layer
    embeddings are concatenated along the channel axis and passed through an
    encoder chosen by ``conf.backbone``:

    * ``None``: a single 1x1 convolution,
    * ``"simple"``: three 3x3 convolutions with ReLU in between,
    * anything else: treated as a config mapping for ``FeatureExtractor``.
    """

    # NOTE(review): "???" looks like the OmegaConf marker for mandatory
    # values that the user config must provide -- confirm against BaseModel.
    default_conf = {
        "embedding_dim": "???",
        "output_dim": None,
        "num_classes": "???",
        "backbone": "???",
        "unary_prior": False,
    }

    def _init(self, conf):
        """Build the per-layer embeddings and the encoder.

        Args:
            conf: configuration with the keys listed in ``default_conf``.
                ``num_classes`` maps a layer name to its number of classes.

        Raises:
            ValueError: if ``output_dim`` is unset and ``backbone`` is
                ``None`` or a string, so that no output dimension can be
                read from a backbone config.
        """
        # One embedding table per map layer. Each table has n + 1 entries,
        # presumably reserving one extra index for "no class" -- TODO confirm.
        self.embeddings = torch.nn.ModuleDict(
            {
                k: torch.nn.Embedding(n + 1, conf.embedding_dim)
                for k, n in conf.num_classes.items()
            }
        )

        input_dim = len(conf.num_classes) * conf.embedding_dim
        output_dim = conf.output_dim
        if output_dim is None:
            # The fallback reads the dimension from the backbone config,
            # which only exists when the backbone is a mapping. Fail with
            # an explicit message instead of an AttributeError on None/str.
            if conf.backbone is None or isinstance(conf.backbone, str):
                raise ValueError(
                    "MapEncoder: `output_dim` must be set explicitly when "
                    f"`backbone` is {conf.backbone!r}."
                )
            output_dim = conf.backbone.output_dim
        if conf.unary_prior:
            # One extra output channel is split off as the prior in _forward.
            output_dim += 1

        if conf.backbone is None:
            self.encoder = nn.Conv2d(input_dim, output_dim, 1)
        elif conf.backbone == "simple":
            self.encoder = nn.Sequential(
                nn.Conv2d(input_dim, 128, 3, padding=1),
                nn.ReLU(inplace=True),
                nn.Conv2d(128, 128, 3, padding=1),
                nn.ReLU(inplace=True),
                nn.Conv2d(128, output_dim, 3, padding=1),
            )
        else:
            # The backbone config is forwarded as-is, with the input and
            # output dimensions overridden by the values computed above.
            self.encoder = FeatureExtractor(
                {
                    **conf.backbone,
                    "input_dim": input_dim,
                    "output_dim": output_dim,
                }
            )

    def _forward(self, data):
        """Encode ``data["map"]`` into a list of feature maps.

        Args:
            data: dict whose ``"map"`` entry is an index tensor; axis 1
                selects the layer in the order (areas, ways, nodes).

        Returns:
            A dict with ``"map_features"`` (list of feature tensors) and,
            when ``conf.unary_prior`` is set, ``"log_prior"`` (list holding
            the last channel of each encoder output).
        """
        # NOTE(review): the layer names are hard-coded here while _init
        # builds the embeddings from conf.num_classes; a config with other
        # keys will fail at this point.
        embeddings = [
            self.embeddings[k](data["map"][:, i])
            for i, k in enumerate(("areas", "ways", "nodes"))
        ]
        # Concatenate along the embedding axis, then move it to axis 1 so
        # that the result is channels-first as expected by the convolutions.
        embeddings = torch.cat(embeddings, dim=-1).permute(0, 3, 1, 2)
        if isinstance(self.encoder, BaseModel):
            features = self.encoder({"image": embeddings})["feature_maps"]
        else:
            # Plain nn modules return one tensor; wrap it so both branches
            # yield a list of feature maps.
            features = [self.encoder(embeddings)]
        pred = {}
        if self.conf.unary_prior:
            # The last channel is the prior; the remaining ones are features.
            pred["log_prior"] = [f[:, -1] for f in features]
            features = [f[:, :-1] for f in features]
        pred["map_features"] = features
        return pred
|
|