| import torch |
| import torch.nn as nn |
| import numpy as np |
| from methods.meta_template import MetaTemplate |
| from methods.gnn import GNN_nl |
| from methods import backbone |
|
|
class GnnNet(MetaTemplate):
    """Few-shot classifier that scores query samples with a graph neural network.

    For each query index ``i`` an episode graph is built from all support
    samples plus the ``i``-th query sample of every class.

    - Each node's feature is its 128-d projection (``self.fc``) concatenated
      with an ``n_way``-d label encoding. The encoding is one-hot for support
      nodes and all zeros for the query node.
    - ``self.gnn`` maps the node features to ``n_way`` values per node.
    - The values of the query nodes are returned as logits.

    NOTE(review): ``self.feature``, ``self.feat_dim`` and ``self.n_query`` are
    not set in this class. They presumably come from ``MetaTemplate`` or the
    training/test loop — confirm that ``self.n_query`` is kept in sync with
    each episode.
    """

    # When True, the projection head is built from ``backbone.Linear_fw`` /
    # ``backbone.BatchNorm1d_fw`` instead of plain ``nn`` layers.
    maml = False

    def __init__(self, model_func, n_way, n_support, tf_path=None):
        super(GnnNet, self).__init__(model_func, n_way, n_support, tf_path=tf_path)

        self.loss_fn = nn.CrossEntropyLoss()

        # Projection head: backbone features -> 128-d node embeddings.
        if not self.maml:
            self.fc = nn.Sequential(
                nn.Linear(self.feat_dim, 128),
                nn.BatchNorm1d(128, track_running_stats=False),
            )
        else:
            self.fc = nn.Sequential(
                backbone.Linear_fw(self.feat_dim, 128),
                backbone.BatchNorm1d_fw(128, track_running_stats=False),
            )
        # Node features are [128-d embedding | n_way-d label encoding].
        self.gnn = GNN_nl(128 + self.n_way, 96, self.n_way)
        self.method = 'GnnNet'

        self.support_label = self._build_support_label()

    def _build_support_label(self):
        """Return the fixed label-encoding part of every episode graph.

        The result has shape ``(1, n_way * (n_support + 1), n_way)`` and is
        laid out class-major. Within each class, the ``n_support`` support
        slots hold that class's one-hot vector and the trailing query slot is
        all zeros.
        """
        # torch.eye avoids the platform-dependent integer dtype that np.repeat
        # produces, which scatter() rejects when it is not int64.
        one_hot = torch.eye(self.n_way).unsqueeze(1).expand(
            self.n_way, self.n_support, self.n_way)
        query_slot = torch.zeros(self.n_way, 1, self.n_way)
        return torch.cat([one_hot, query_slot], dim=1).view(1, -1, self.n_way)

    def cuda(self, device=None):
        """Move the sub-networks and the cached label tensor to a CUDA device.

        ``support_label`` is a plain tensor attribute, not a registered
        buffer, so it has to be moved explicitly. ``device`` is forwarded
        unchanged (``None`` means the current device), as in
        ``nn.Module.cuda``.
        """
        self.feature.cuda(device)
        self.fc.cuda(device)
        self.gnn.cuda(device)
        self.support_label = self.support_label.cuda(device)
        return self

    def set_forward(self, x, is_feature=False):
        """Score every query sample of one episode.

        Args:
            x: The episode, with first two dims ``(n_way, n_support + n_query)``.
                If ``is_feature`` is False, these are raw inputs for
                ``self.feature``. If it is True, they are precomputed backbone
                features, presumably ``(n_way, n_support + n_query, feat_dim)``
                — TODO confirm.
            is_feature: When True, skip the backbone and feed ``x`` straight
                to ``self.fc``.

        Returns:
            Logits of shape ``(n_way * n_query, n_way)``, ordered class-major
            (all queries of class 0 first, then class 1, ...).
        """
        # Follow the model's device instead of hard-coding .cuda(). The model
        # then also runs on CPU, and a CUDA model sees the same move as before.
        x = x.to(next(self.fc.parameters()).device)

        if is_feature:
            # Precomputed features must hold exactly the configured number of
            # samples per class (previously hard-coded to n_support + 15).
            assert x.size(1) == self.n_support + self.n_query, (
                'expected %d samples per class, got %d'
                % (self.n_support + self.n_query, x.size(1)))
            z = self.fc(x.view(-1, *x.size()[2:]))
        else:
            x = x.view(-1, *x.size()[2:])
            z = self.fc(self.feature(x))
        z = z.view(self.n_way, -1, z.size(1))

        # One graph per query index i: all support nodes plus the i-th query
        # of every class, flattened class-major to (1, n_way*(n_support+1), d).
        support = z[:, :self.n_support]
        z_stack = [
            torch.cat(
                [support, z[:, self.n_support + i:self.n_support + i + 1]], dim=1
            ).view(1, -1, z.size(2))
            for i in range(self.n_query)
        ]
        assert z_stack[0].size(1) == self.n_way * (self.n_support + 1)

        return self.forward_gnn(z_stack)

    def forward_gnn(self, zs):
        """Run the GNN over a list of episode graphs and keep the query logits.

        Args:
            zs: ``n_query`` tensors, each ``(1, n_way * (n_support + 1), d)``,
                as built by ``set_forward``.

        Returns:
            Query-node logits of shape ``(n_way * n_query, n_way)``, ordered
            class-major.
        """
        # Put the cached labels next to the inputs. This is a no-op when they
        # are already there, and it also covers models moved with .to().
        label = self.support_label.to(zs[0].device)
        nodes = torch.cat([torch.cat([z, label], dim=2) for z in zs], dim=0)
        scores = self.gnn(nodes)

        # The view requires n_way values per node; presumably the GNN output
        # is (n_query, n_nodes, n_way). Keep the last (query) slot of every
        # class, then reorder rows to class-major (n_way, n_query).
        scores = scores.view(
            self.n_query, self.n_way, self.n_support + 1, self.n_way)[:, :, -1]
        return scores.permute(1, 0, 2).contiguous().view(-1, self.n_way)

    def set_forward_loss(self, x):
        """Return ``(scores, loss)`` for one episode.

        Targets follow the class-major row order of ``set_forward``:
        ``[0]*n_query + [1]*n_query + ... + [n_way-1]*n_query``.
        """
        scores = self.set_forward(x)
        # Targets are built with torch so they are always int64, as
        # CrossEntropyLoss requires. np.repeat over a range yields int32 on
        # Windows with NumPy < 2. They are placed on the same device as scores.
        y_query = torch.arange(
            self.n_way, device=scores.device).repeat_interleave(self.n_query)
        loss = self.loss_fn(scores, y_query)
        return scores, loss
|
|