import mxnet as mx
import logging
from nowcasting.ops import *  # expected to provide `cfg` and `activation` used below
from nowcasting.operators.common import identity, grid_generator, group_add, constant, save_npy
from nowcasting.operators.conv_rnn import BaseConvRNN
import numpy as np


def flow_conv(data, num_filter, flows, weight, bias, name):
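    """Warp the input with every flow field in `flows` and fuse the results.

    Each flow is negated and converted into a sampling grid (`GridGenerator`
    with the "warp" transform), the input is bilinearly sampled along that
    grid, and the warped copies are concatenated on the channel axis and
    projected to `num_filter` channels by a 1x1 convolution.
    """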
    assert isinstance(flows, list)
    warped_data = []
    for i in range(len(flows)):
        flow = flows[i]
        grid = mx.sym.GridGenerator(data=-flow, transform_type="warp")
        ele_dat = mx.sym.BilinearSampler(data=data, grid=grid)
        warped_data.append(ele_dat)
    data = mx.sym.concat(*warped_data, dim=1)
    ret = mx.sym.Convolution(data=data,
                             num_filter=num_filter,
                             kernel=(1, 1),
                             weight=weight,
                             bias=bias,
                             name=name)
    return ret


class TrajGRU(BaseConvRNN):
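    """Trajectory GRU (TrajGRU) cell built from MXNet symbols.

    At every step a small sub-network (`_flow_generator`) predicts L flow
    fields from the current input and the previous hidden state; the previous
    state is warped along each flow and fused by `flow_conv` before entering
    the reset/update/new-memory computation of a ConvGRU-style cell.
    """
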
    def __init__(self, b_h_w, num_filter, zoneout=0.0, L=5,
                 i2h_kernel=(3, 3), i2h_stride=(1, 1), i2h_pad=(1, 1),
                 h2h_kernel=(5, 5), h2h_dilate=(1, 1),
                 act_type="leaky",
                 prefix="TrajGRU", lr_mult=1.0):
        super(TrajGRU, self).__init__(num_filter=num_filter,
                                      b_h_w=b_h_w,
                                      h2h_kernel=h2h_kernel,
                                      h2h_dilate=h2h_dilate,
                                      i2h_kernel=i2h_kernel,
                                      i2h_pad=i2h_pad,
                                      i2h_stride=i2h_stride,
                                      act_type=act_type,
                                      prefix=prefix)
        self._L = L
        self._zoneout = zoneout
        # Parameters of the flow-generating sub-network (input/state -> L flow fields).
        self.i2f_conv1_weight = self.params.get("i2f_conv1_weight", lr_mult=lr_mult)
        self.i2f_conv1_bias = self.params.get("i2f_conv1_bias", lr_mult=lr_mult)
        self.h2f_conv1_weight = self.params.get("h2f_conv1_weight", lr_mult=lr_mult)
        self.h2f_conv1_bias = self.params.get("h2f_conv1_bias", lr_mult=lr_mult)
        self.f_conv2_weight = self.params.get("f_conv2_weight", lr_mult=lr_mult)
        self.f_conv2_bias = self.params.get("f_conv2_bias", lr_mult=lr_mult)
        if cfg.MODEL.TRAJRNN.INIT_GRID:
            logging.info("TrajGRU: Initialize Grid Using Zeros!")
            self.f_out_weight = self.params.get("f_out_weight",
                                                lr_mult=lr_mult * cfg.MODEL.TRAJRNN.FLOW_LR_MULT,
                                                init=mx.init.Zero())
            self.f_out_bias = self.params.get("f_out_bias",
                                              lr_mult=lr_mult * cfg.MODEL.TRAJRNN.FLOW_LR_MULT,
                                              init=mx.init.Zero())
        else:
            self.f_out_weight = self.params.get("f_out_weight", lr_mult=lr_mult)
            self.f_out_bias = self.params.get("f_out_bias", lr_mult=lr_mult)
        self.i2h_weight = self.params.get("i2h_weight", lr_mult=lr_mult)
        self.i2h_bias = self.params.get("i2h_bias", lr_mult=lr_mult)
        self.h2h_weight = self.params.get("h2h_weight", lr_mult=lr_mult)
        self.h2h_bias = self.params.get("h2h_bias", lr_mult=lr_mult)

    @property
    def state_postfix(self):
        return ['h']

    @property
    def state_info(self):
        return [{'shape': (self._batch_size, self._num_filter,
                           self._state_height, self._state_width),
                 '__layout__': "NCHW"}]

    def _flow_generator(self, inputs, states, prefix):
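        """Generate the L flow fields used to warp the previous hidden state.

        A 5x5 convolution is applied to the input (when given) and to the
        previous state; the responses are summed, passed through the cell's
        activation, and projected to L * 2 channels, which are split into a
        list of L two-channel flow fields.
        """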
        if inputs is not None:
            i2f_conv1 = mx.sym.Convolution(data=inputs,
                                           weight=self.i2f_conv1_weight,
                                           bias=self.i2f_conv1_bias,
                                           kernel=(5, 5),
                                           dilate=(1, 1),
                                           pad=(2, 2),
                                           num_filter=32,
                                           name="%s_i2f_conv1" % prefix)
        else:
            i2f_conv1 = None
        h2f_conv1 = mx.sym.Convolution(data=states,
                                       weight=self.h2f_conv1_weight,
                                       bias=self.h2f_conv1_bias,
                                       kernel=(5, 5),
                                       dilate=(1, 1),
                                       pad=(2, 2),
                                       num_filter=32,
                                       name="%s_h2f_conv1" % prefix)
        f_conv1 = i2f_conv1 + h2f_conv1 if i2f_conv1 is not None else h2f_conv1
        f_conv1 = activation(f_conv1, act_type=self._act_type)

        # L flow fields, stored as L * 2 channels (two channels per flow).
        flows = mx.sym.Convolution(data=f_conv1,
                                   weight=self.f_out_weight,
                                   bias=self.f_out_bias,
                                   kernel=(5, 5),
                                   pad=(2, 2),
                                   num_filter=self._L * 2)
        if cfg.MODEL.TRAJRNN.SAVE_MID_RESULTS:
            import os
            flows = save_npy(flows, save_name="%s_flow" % prefix,
                             save_dir=os.path.join(cfg.MODEL.SAVE_DIR, "flows"))
        flows = mx.sym.split(flows, num_outputs=self._L, axis=1)
        flows = [flows[i] for i in range(self._L)]
        return flows

    def __call__(self, inputs, states=None, is_initial=False, ret_mid=False):
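        """Run one TrajGRU step.

        When `is_initial` is True the previous state comes from `begin_state()`;
        otherwise it is taken from `states[0]`. Returns `(output, [new_state])`,
        or `(output, [new_state], [])` when `ret_mid` is True, where `output`
        is the same symbol as the new state.
        """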
        self._counter += 1
        name = '%s_t%d' % (self._prefix, self._counter)
        if is_initial:
            states = self.begin_state()[0]
        else:
            states = states[0]
        assert states is not None
        if inputs is not None:
            # Input-to-hidden projection: 3 * num_filter channels for the
            # reset gate, update gate and new-memory terms.
            i2h = mx.sym.Convolution(data=inputs,
                                     weight=self.i2h_weight,
                                     bias=self.i2h_bias,
                                     kernel=self._i2h_kernel,
                                     stride=self._i2h_stride,
                                     dilate=self._i2h_dilate,
                                     pad=self._i2h_pad,
                                     num_filter=self._num_filter * 3,
                                     name="%s_i2h" % name)
            i2h_slice = mx.sym.SliceChannel(i2h, num_outputs=3, axis=1)
        else:
            i2h_slice = None
        prev_h = states
        flows = self._flow_generator(inputs=inputs, states=states, prefix=name)

        # Warp the previous state along every flow field and fuse with a 1x1 conv.
        h2h = flow_conv(data=prev_h, num_filter=self._num_filter * 3, flows=flows,
                        weight=self.h2h_weight, bias=self.h2h_bias, name="%s_h2h" % name)
        h2h_slice = mx.sym.SliceChannel(h2h, num_outputs=3, axis=1)
        if i2h_slice is not None:
            reset_gate = mx.sym.Activation(i2h_slice[0] + h2h_slice[0], act_type="sigmoid",
                                           name=name + "_r")
            update_gate = mx.sym.Activation(i2h_slice[1] + h2h_slice[1], act_type="sigmoid",
                                            name=name + "_u")
            new_mem = activation(i2h_slice[2] + reset_gate * h2h_slice[2],
                                 act_type=self._act_type,
                                 name=name + "_h")
        else:
            reset_gate = mx.sym.Activation(h2h_slice[0], act_type="sigmoid",
                                           name=name + "_r")
            update_gate = mx.sym.Activation(h2h_slice[1], act_type="sigmoid",
                                            name=name + "_u")
            new_mem = activation(reset_gate * h2h_slice[2],
                                 act_type=self._act_type,
                                 name=name + "_h")
        next_h = update_gate * prev_h + (1 - update_gate) * new_mem
        if self._zoneout > 0.0:
            # Zoneout: where the dropped-out all-ones mask is zero, the previous state
            # is carried over, so each unit keeps its old value with probability `zoneout`.
            mask = mx.sym.Dropout(mx.sym.ones_like(prev_h), p=self._zoneout)
            next_h = mx.sym.where(mask, next_h, prev_h)
        self._curr_states = [next_h]
        if not ret_mid:
            return next_h, [next_h]
        else:
            return next_h, [next_h], []
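

if __name__ == "__main__":
    # Minimal smoke-test sketch (not part of the original training pipeline): it builds a
    # single TrajGRU step on a dummy symbolic input and prints the inferred output shape.
    # It assumes the `nowcasting` package is importable and that `cfg.MODEL.TRAJRNN`
    # (INIT_GRID, FLOW_LR_MULT, SAVE_MID_RESULTS) is configured; the input channel count
    # and spatial size below are arbitrary choices for illustration.
    batch_size, height, width = 2, 64, 64
    cell = TrajGRU(b_h_w=(batch_size, height, width), num_filter=32, L=5,
                   prefix="traj_gru_demo")
    data = mx.sym.Variable("data")
    out, states = cell(inputs=data, is_initial=True)
    # Hypothetical input: 8 channels at the state resolution (i2h stride is (1, 1) here).
    _, out_shapes, _ = out.infer_shape(data=(batch_size, 8, height, width))
    print("TrajGRU output shape:", out_shapes[0])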