| import math |
|
|
| import numpy as np |
| import torch |
| import torch.nn as nn |
| from torch.autograd import Variable |
| from graph import Graph |
| import pytorch_lightning as pl |
| from torchmetrics.classification import MulticlassAccuracy, BinaryAccuracy |
| import torch.optim as optim |
|
|
def import_class(name):
    """Resolve a dotted path such as ``"package.module.ClassName"`` to an object.

    The first component is imported as a module. Each following component is
    looked up as an attribute of the object resolved so far. When that lookup
    fails, the dotted prefix is imported as a submodule instead. This lets paths
    through packages whose ``__init__`` does not import their submodules still
    resolve.

    Args:
        name: Dotted import path.

    Returns:
        The module, class, function or other object the path refers to.

    Raises:
        ImportError: if the first component cannot be imported.
        AttributeError: if a later component is neither an attribute nor an
            importable submodule.
    """
    import importlib

    components = name.split('.')
    obj = importlib.import_module(components[0])
    prefix = components[0]
    for comp in components[1:]:
        prefix = f"{prefix}.{comp}"
        try:
            obj = getattr(obj, comp)
        except AttributeError as missing_attr:
            # The component may be a submodule that has not been imported yet.
            try:
                obj = importlib.import_module(prefix)
            except ModuleNotFoundError:
                raise missing_attr from None
    return obj
|
|
|
|
def conv_branch_init(conv, branches):
    """Initialise a convolution whose output is summed with ``branches - 1`` others.

    Weights are drawn from
    ``N(0, sqrt(2 / (out_channels * in_channels * k * branches)))``, where ``k``
    is ``weight.size(2)`` (the first kernel dimension). The variance is thus
    shared across the summed branches. The bias, if the conv has one, is set to 0.

    Args:
        conv: A convolution module exposing ``weight`` (>= 3-D) and ``bias``.
        branches: Number of parallel branches whose outputs are added together.
    """
    weight = conv.weight
    out_c, in_c, k = weight.size(0), weight.size(1), weight.size(2)
    nn.init.normal_(weight, 0, math.sqrt(2. / (out_c * in_c * k * branches)))
    # Convolutions built with bias=False have conv.bias == None; skip instead of crashing.
    if conv.bias is not None:
        nn.init.constant_(conv.bias, 0)
|
|
|
|
def conv_init(conv):
    """Kaiming-normal (fan_out) initialise ``conv.weight`` and zero its bias.

    Args:
        conv: A convolution module exposing ``weight`` and ``bias``. A ``bias``
            of None (``bias=False`` layers) is left untouched.
    """
    nn.init.kaiming_normal_(conv.weight, mode='fan_out')
    # Convolutions built with bias=False have conv.bias == None; skip instead of crashing.
    if conv.bias is not None:
        nn.init.constant_(conv.bias, 0)
|
|
|
|
def bn_init(bn, scale):
    """Fill a batch-norm layer's affine weight with ``scale`` and its bias with 0.

    Args:
        bn: A normalisation module exposing ``weight`` and ``bias`` parameters.
        scale: Constant written into every element of ``bn.weight``.
    """
    for param, value in ((bn.weight, scale), (bn.bias, 0)):
        nn.init.constant_(param, value)
|
|
|
|
class unit_tcn(nn.Module):
    """Convolution along dim 2 with a ``(kernel_size, 1)`` kernel, then BatchNorm2d.

    ``stride`` applies to dim 2 only. Padding is ``int((kernel_size - 1) / 2)``
    on dim 2, which keeps that length unchanged for odd kernels at stride 1. No
    activation is applied in ``forward``.
    """

    def __init__(self, in_channels, out_channels, kernel_size=9, stride=1):
        super().__init__()
        half_window = int((kernel_size - 1) / 2)
        self.conv = nn.Conv2d(
            in_channels,
            out_channels,
            kernel_size=(kernel_size, 1),
            padding=(half_window, 0),
            stride=(stride, 1),
        )
        self.bn = nn.BatchNorm2d(out_channels)
        # Kept for attribute/repr compatibility; forward() does not use it.
        self.relu = nn.ReLU(inplace=True)
        conv_init(self.conv)
        bn_init(self.bn, 1)

    def forward(self, x):
        return self.bn(self.conv(x))
|
|
|
|
class unit_gcn(nn.Module):
    """Graph convolution over ``num_subset`` adjacency matrices with optional attention.

    ``forward`` unpacks its input as ``(N, C, T, V)`` and returns a tensor with
    ``out_channels`` channels and the same ``T`` and ``V``.

    Args:
        in_channels: Channels of the input tensor.
        out_channels: Channels produced by every branch (branches are summed).
        A: NumPy array indexed as ``A[i]`` for ``i < num_subset``. Each ``A[i]`` is a
            ``(V, V)`` matrix; ``A.shape[-1]`` is taken as ``V``.
        coff_embedding: Divisor giving the embedding width
            ``out_channels // coff_embedding`` used by the adaptive branch.
        num_subset: Number of adjacency matrices / parallel branches.
        adaptive: If True, ``A`` becomes the learnable parameter ``PA``, refined per
            sample by a tanh affinity term scaled by ``alpha``. Otherwise ``A`` is
            a fixed, non-persistent buffer.
        attention: If True, apply spatial, temporal and channel attention after
            the residual connection.
    """

    def __init__(self, in_channels, out_channels, A, coff_embedding=4, num_subset=3, adaptive=True, attention=True):
        super(unit_gcn, self).__init__()
        inter_channels = out_channels // coff_embedding
        self.inter_c = inter_channels
        self.out_c = out_channels
        self.in_c = in_channels
        self.num_subset = num_subset
        num_jpts = A.shape[-1]

        # One 1x1 projection per adjacency matrix; branch outputs are summed.
        self.conv_d = nn.ModuleList()
        for i in range(self.num_subset):
            self.conv_d.append(nn.Conv2d(in_channels, out_channels, 1))

        if adaptive:
            # Learnable adjacency initialised from A, plus a data-dependent term
            # whose weight alpha starts at 0.
            self.PA = nn.Parameter(torch.from_numpy(A.astype(np.float32)))
            self.alpha = nn.Parameter(torch.zeros(1))
            self.conv_a = nn.ModuleList()
            self.conv_b = nn.ModuleList()
            for i in range(self.num_subset):
                self.conv_a.append(nn.Conv2d(in_channels, inter_channels, 1))
                self.conv_b.append(nn.Conv2d(in_channels, inter_channels, 1))
        else:
            # Fixed adjacency. A non-persistent buffer moves with the module on
            # .to()/.cuda() and does not add a state_dict key.
            self.register_buffer('A', torch.from_numpy(A.astype(np.float32)), persistent=False)
        self.adaptive = adaptive

        if attention:
            # Temporal attention: 1-D conv over T. It is zero-initialised, so the
            # sigmoid weight starts at 0.5.
            self.conv_ta = nn.Conv1d(out_channels, 1, 9, padding=4)
            nn.init.constant_(self.conv_ta.weight, 0)
            nn.init.constant_(self.conv_ta.bias, 0)

            # Spatial attention: odd kernel (V or V - 1) with matching padding keeps length V.
            ker_jpt = num_jpts - 1 if not num_jpts % 2 else num_jpts
            pad = (ker_jpt - 1) // 2
            self.conv_sa = nn.Conv1d(out_channels, 1, ker_jpt, padding=pad)
            nn.init.xavier_normal_(self.conv_sa.weight)
            nn.init.constant_(self.conv_sa.bias, 0)

            # Channel attention: squeeze-and-excitation bottleneck with reduction rr.
            rr = 2
            self.fc1c = nn.Linear(out_channels, out_channels // rr)
            self.fc2c = nn.Linear(out_channels // rr, out_channels)
            nn.init.kaiming_normal_(self.fc1c.weight)
            nn.init.constant_(self.fc1c.bias, 0)
            nn.init.constant_(self.fc2c.weight, 0)
            nn.init.constant_(self.fc2c.bias, 0)
        self.attention = attention

        if in_channels != out_channels:
            self.down = nn.Sequential(
                nn.Conv2d(in_channels, out_channels, 1),
                nn.BatchNorm2d(out_channels)
            )
        else:
            # nn.Identity (unlike a lambda) is picklable and has no parameters.
            self.down = nn.Identity()

        self.bn = nn.BatchNorm2d(out_channels)
        self.soft = nn.Softmax(-2)
        self.tan = nn.Tanh()
        self.sigmoid = nn.Sigmoid()
        self.relu = nn.ReLU(inplace=True)

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                conv_init(m)
            elif isinstance(m, nn.BatchNorm2d):
                bn_init(m, 1)
        # Near-zero output BN scale: at init the block output is dominated by down(x).
        bn_init(self.bn, 1e-6)
        for i in range(self.num_subset):
            conv_branch_init(self.conv_d[i], self.num_subset)

    def forward(self, x):
        N, C, T, V = x.size()

        y = None
        if self.adaptive:
            A = self.PA
            for i in range(self.num_subset):
                # Sample-specific (V, V) affinity from the embedded features, squashed by tanh.
                A1 = self.conv_a[i](x).permute(0, 3, 1, 2).contiguous().view(N, V, self.inter_c * T)
                A2 = self.conv_b[i](x).view(N, self.inter_c * T, V)
                A1 = self.tan(torch.matmul(A1, A2) / A1.size(-1))
                A1 = A[i] + A1 * self.alpha
                A2 = x.view(N, C * T, V)
                z = self.conv_d[i](torch.matmul(A2, A1).view(N, C, T, V))
                y = z + y if y is not None else z
        else:
            # The buffer normally already sits on x's device; .to() is then a no-op.
            A = self.A.to(x.device)
            for i in range(self.num_subset):
                A2 = x.view(N, C * T, V)
                z = self.conv_d[i](torch.matmul(A2, A[i]).view(N, C, T, V))
                y = z + y if y is not None else z

        y = self.bn(y)
        y += self.down(x)
        y = self.relu(y)

        if self.attention:
            y = self._apply_attention(y)
        return y

    def _apply_attention(self, y):
        """Apply spatial, temporal, then channel attention, each as ``y * w + y``."""
        # Spatial: average over T (dim -2), one weight per joint.
        se = y.mean(-2)
        se1 = self.sigmoid(self.conv_sa(se))
        y = y * se1.unsqueeze(-2) + y

        # Temporal: average over V (dim -1), one weight per frame.
        se = y.mean(-1)
        se1 = self.sigmoid(self.conv_ta(se))
        y = y * se1.unsqueeze(-1) + y

        # Channel: average over T and V, one weight per channel.
        se = y.mean(-1).mean(-1)
        se1 = self.relu(self.fc1c(se))
        se2 = self.sigmoid(self.fc2c(se1))
        y = y * se2.unsqueeze(-1).unsqueeze(-1) + y
        return y
|
|
|
|
class TCN_GCN_unit(nn.Module):
    """One block computing ``relu(tcn1(gcn1(x)) + residual(x))``.

    Args:
        in_channels: Channels of the input.
        out_channels: Channels of the output.
        A: Adjacency array passed through to ``unit_gcn``.
        stride: Stride of the temporal conv on dim 2, also used by the residual
            projection.
        residual: If False, no skip connection is added.
        adaptive: Forwarded to ``unit_gcn``.
        attention: Forwarded to ``unit_gcn``.
    """

    def __init__(self, in_channels, out_channels, A, stride=1, residual=True, adaptive=True, attention=True):
        super(TCN_GCN_unit, self).__init__()
        self.gcn1 = unit_gcn(in_channels, out_channels, A, adaptive=adaptive, attention=attention)
        self.tcn1 = unit_tcn(out_channels, out_channels, stride=stride)
        self.relu = nn.ReLU(inplace=True)
        # Stored for introspection only; any attention is applied inside gcn1.
        self.attention = attention

        if not residual:
            # Adding the scalar 0 leaves the main path unchanged.
            self.residual = lambda x: 0
        elif (in_channels == out_channels) and (stride == 1):
            # Shapes already match: plain skip connection (picklable, no parameters).
            self.residual = nn.Identity()
        else:
            # 1x1 conv + BN to match channel count and stride of the main path.
            self.residual = unit_tcn(in_channels, out_channels, kernel_size=1, stride=stride)

    def forward(self, x):
        # The attention and non-attention cases used to be separate but identical
        # branches; attention is handled entirely inside gcn1.
        return self.relu(self.tcn1(self.gcn1(x)) + self.residual(x))
|
|
|
|
class Model(pl.LightningModule):
    """Ten-layer adaptive graph-convolution classifier wrapped as a LightningModule.

    ``forward`` unpacks its input as ``(N, C, T, V, M)`` and returns logits of
    shape ``(N, num_class)``. T, V and M are averaged out after the backbone.

    Args:
        num_class: Number of output classes.
        num_point: V, used to size the input batch norm.
        num_person: M, used to size the input batch norm.
        graph: Accepted for configuration compatibility but not used. The
            ``Graph`` class imported by this module is always instantiated.
        graph_args: Keyword arguments for ``Graph``. None means no arguments.
        in_channels: C, channels per joint.
        drop_out: Dropout probability before the classifier; 0 disables it.
        adaptive: Forwarded to every ``TCN_GCN_unit``.
        attention: Forwarded to every ``TCN_GCN_unit``.
        learning_rate: Adam learning rate.
        weight_decay: Adam weight decay.
    """

    def __init__(self, num_class=60, num_point=25, num_person=2, graph=None, graph_args=None, in_channels=3,
                 drop_out=0, adaptive=True, attention=True, learning_rate=1e-4, weight_decay=1e-4):
        super(Model, self).__init__()
        # Fresh dict per instance instead of a shared mutable default argument.
        if graph_args is None:
            graph_args = {}

        # NOTE(review): `graph` is intentionally ignored (see class docstring).
        self.graph = Graph(**graph_args)
        A = self.graph.A
        self.num_class = num_class

        # Normalises each (person, joint, channel) feature over batch and time (see forward).
        self.data_bn = nn.BatchNorm1d(num_person * in_channels * num_point)

        self.l1 = TCN_GCN_unit(in_channels, 64, A, residual=False, adaptive=adaptive, attention=attention)
        self.l2 = TCN_GCN_unit(64, 64, A, adaptive=adaptive, attention=attention)
        self.l3 = TCN_GCN_unit(64, 64, A, adaptive=adaptive, attention=attention)
        self.l4 = TCN_GCN_unit(64, 64, A, adaptive=adaptive, attention=attention)
        self.l5 = TCN_GCN_unit(64, 128, A, stride=2, adaptive=adaptive, attention=attention)
        self.l6 = TCN_GCN_unit(128, 128, A, adaptive=adaptive, attention=attention)
        self.l7 = TCN_GCN_unit(128, 128, A, adaptive=adaptive, attention=attention)
        self.l8 = TCN_GCN_unit(128, 256, A, stride=2, adaptive=adaptive, attention=attention)
        self.l9 = TCN_GCN_unit(256, 256, A, adaptive=adaptive, attention=attention)
        self.l10 = TCN_GCN_unit(256, 256, A, adaptive=adaptive, attention=attention)

        self.fc = nn.Linear(256, num_class)
        nn.init.normal_(self.fc.weight, 0, math.sqrt(2. / num_class))
        bn_init(self.data_bn, 1)
        if drop_out:
            self.drop_out = nn.Dropout(drop_out)
        else:
            # nn.Identity (unlike a lambda) is picklable and adds no parameters.
            self.drop_out = nn.Identity()

        self.loss = nn.CrossEntropyLoss()
        # NOTE(review): one metric object is shared by train/val/test; the per-batch
        # values logged below are unaffected, but its accumulated state mixes stages.
        self.metric = MulticlassAccuracy(num_class)

        self.learning_rate = learning_rate
        self.weight_decay = weight_decay
        # Per-batch validation results, averaged and cleared in on_validation_epoch_end.
        self.validation_step_loss_outputs = []
        self.validation_step_acc_outputs = []

        self.save_hyperparameters()

    def forward(self, x):
        N, C, T, V, M = x.size()

        # Fold persons, joints and channels into one axis for data_bn, then treat
        # every person as a separate sample of shape (C, T, V).
        x = x.permute(0, 4, 3, 1, 2).contiguous().view(N, M * V * C, T)
        x = self.data_bn(x.float())
        x = x.view(N, M, V, C, T).permute(0, 1, 3, 4, 2).contiguous().view(N * M, C, T, V)

        for layer in (self.l1, self.l2, self.l3, self.l4, self.l5,
                      self.l6, self.l7, self.l8, self.l9, self.l10):
            x = layer(x)

        # Average over the remaining time/joint positions, then over persons.
        c_new = x.size(1)
        x = x.view(N, M, c_new, -1)
        x = x.mean(3).mean(1)
        x = self.drop_out(x)

        return self.fc(x)

    def _evaluate_batch(self, batch):
        """Return ``(loss, accuracy, predicted_classes, targets)`` for an ``(inputs, targets)`` batch."""
        inputs, targets = batch
        outputs = self(inputs)
        # Softmax is monotonic, so argmax over the logits selects the same class.
        y_pred_class = torch.argmax(outputs, dim=1)
        accuracy = self.metric(y_pred_class, targets)
        loss = self.loss(outputs, targets)
        return loss, accuracy, y_pred_class, targets

    def training_step(self, batch, batch_idx):
        loss, train_accuracy, _, _ = self._evaluate_batch(batch)
        self.log('train_accuracy', train_accuracy, prog_bar=True, on_epoch=True)
        self.log('train_loss', loss, prog_bar=True, on_epoch=True)
        return loss

    def validation_step(self, batch, batch_idx):
        loss, valid_accuracy, _, _ = self._evaluate_batch(batch)
        self.log('valid_accuracy', valid_accuracy, prog_bar=True, on_epoch=True)
        self.log('valid_loss', loss, prog_bar=True, on_epoch=True)
        self.validation_step_loss_outputs.append(loss)
        self.validation_step_acc_outputs.append(valid_accuracy)
        return {"valid_loss" : loss, "valid_accuracy" : valid_accuracy}

    def on_validation_epoch_end(self):
        # torch.stack([]) raises, so only log when at least one batch was collected.
        if self.validation_step_loss_outputs:
            avg_loss = torch.stack(self.validation_step_loss_outputs).mean()
            avg_acc = torch.stack(self.validation_step_acc_outputs).mean()
            self.log("ptl/val_loss", avg_loss)
            self.log("ptl/val_accuracy", avg_acc)
        self.validation_step_loss_outputs.clear()
        self.validation_step_acc_outputs.clear()

    def test_step(self, batch, batch_idx):
        loss, test_accuracy, y_pred_class, targets = self._evaluate_batch(batch)
        print("Targets : ", targets)
        print("Preds : ", y_pred_class)
        self.log('test_accuracy', test_accuracy, prog_bar=True, on_epoch=True)
        self.log('test_loss', loss, prog_bar=True, on_epoch=True)
        return {"test_loss" : loss, "test_accuracy" : test_accuracy}

    def configure_optimizers(self):
        params = self.parameters()
        optimizer = optim.Adam(params=params, lr = self.learning_rate, weight_decay = self.weight_decay)
        # mode='max' because the monitored quantity is an accuracy (higher is better).
        scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='max')
        return {"optimizer": optimizer,
                "lr_scheduler": {"scheduler": scheduler, "monitor": "valid_accuracy"}
                }

    def predict_step(self, batch, batch_idx):
        # Accept a bare input tensor or an (inputs, targets) pair like the other loaders yield.
        inputs = batch[0] if isinstance(batch, (tuple, list)) else batch
        return self(inputs)
|
|
if __name__ == "__main__":
    # Smoke test: build a small model, print its summary, and print the logits
    # shape for one random input of shape (1, 2, 80, 18, 1).
    import os
    from torchinfo import summary

    print(os.getcwd())
    run_device = "cuda" if torch.cuda.is_available() else "cpu"
    net = Model(
        num_class=20,
        num_point=18,
        num_person=1,
        graph_args={},
        in_channels=2,
    ).to(run_device)
    summary(net)
    dummy_input = torch.randn((1, 2, 80, 18, 1)).to(run_device)
    logits = net(dummy_input)
    print(logits.shape)