| import mxnet as mx |
| from nowcasting.ops import * |
| from nowcasting.operators.common import identity, grid_generator, group_add |
| from nowcasting.operators.base_rnn import MyBaseRNNCell |
import numpy as np


class BaseConvRNN(MyBaseRNNCell):
| def __init__(self, num_filter, b_h_w, |
| h2h_kernel=(3, 3), h2h_dilate=(1, 1), |
| i2h_kernel=(3, 3), i2h_stride=(1, 1), |
| i2h_pad=(1, 1), i2h_dilate=(1, 1), |
| act_type="tanh", prefix="ConvRNN", params=None): |
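        """Base class for convolutional RNN cells.

        Stores the input-to-state (i2h) and state-to-state (h2h) convolution
        hyper-parameters and derives the spatial resolution of the recurrent
        state from the input resolution.
        """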
| super(BaseConvRNN, self).__init__(prefix=prefix + "_", params=params) |
| self._num_filter = num_filter |
| self._h2h_kernel = h2h_kernel |
        assert (self._h2h_kernel[0] % 2 == 1) and (self._h2h_kernel[1] % 2 == 1), \
            "Only odd kernel sizes are supported, got h2h_kernel=%s" % str(h2h_kernel)
| self._h2h_pad = (h2h_dilate[0] * (h2h_kernel[0] - 1) // 2, |
| h2h_dilate[1] * (h2h_kernel[1] - 1) // 2) |
| self._h2h_dilate = h2h_dilate |
| self._i2h_kernel = i2h_kernel |
| self._i2h_stride = i2h_stride |
| self._i2h_pad = i2h_pad |
| self._i2h_dilate = i2h_dilate |
| self._act_type = act_type |
| assert len(b_h_w) == 3 |
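        # Effective kernel size after dilation; the state resolution then follows
        # the standard convolution output-size formula:
        #   out = (in + 2 * pad - dilated_ksize) // stride + 1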
| i2h_dilate_ksize_h = 1 + (self._i2h_kernel[0] - 1) * self._i2h_dilate[0] |
| i2h_dilate_ksize_w = 1 + (self._i2h_kernel[1] - 1) * self._i2h_dilate[1] |
| self._batch_size, self._height, self._width = b_h_w |
| self._state_height = (self._height + 2 * self._i2h_pad[0] - i2h_dilate_ksize_h)\ |
| // self._i2h_stride[0] + 1 |
| self._state_width = (self._width + 2 * self._i2h_pad[1] - i2h_dilate_ksize_w) \ |
| // self._i2h_stride[1] + 1 |
        print("%s: state size = %d x %d" % (self._prefix, self._state_height, self._state_width))
| self._curr_states = None |
        self._counter = 0


class ConvRNN(BaseConvRNN):
| def __init__(self, num_filter, b_h_w, |
| h2h_kernel=(3, 3), h2h_dilate=(1, 1), |
| i2h_kernel=(3, 3), i2h_stride=(1, 1), |
| i2h_pad=(1, 1), i2h_dilate=(1, 1), |
| act_type="leaky", |
| layer_norm=False, |
| prefix="ConvRNN", |
| params=None): |
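        """Plain convolutional RNN cell:

            h_t = f(W \ast x_t + R \ast h_{t-1} + b)

        where \ast denotes convolution and f is the configured activation
        (default: leaky ReLU), optionally preceded by layer normalization.
        """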
        super(ConvRNN, self).__init__(num_filter=num_filter,
                                      b_h_w=b_h_w,
                                      h2h_kernel=h2h_kernel,
                                      h2h_dilate=h2h_dilate,
                                      i2h_kernel=i2h_kernel,
                                      i2h_stride=i2h_stride,
                                      i2h_pad=i2h_pad,
                                      i2h_dilate=i2h_dilate,
                                      act_type=act_type,
                                      prefix=prefix,
                                      params=params)
| self._layer_norm = layer_norm |
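        # Parameters of the input-to-state and state-to-state convolutions. Note
        # that h2h_bias is given an explicit Normal initializer, while the other
        # parameters fall back to the initializer supplied at binding time.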
| self.i2h_weight = self.params.get('i2h_weight') |
| self.i2h_bias = self.params.get('i2h_bias') |
| self.h2h_weight = self.params.get('h2h_weight') |
| self.h2h_bias = self.params.get('h2h_bias', init=mx.init.Normal()) |
|
|
| @property |
| def state_info(self): |
| return [{'shape': (self._batch_size, self._num_filter, |
| self._state_height, self._state_width), |
| '__layout__': "NCHW"}] |
|
|
| def __call__(self, inputs, states=None, is_initial=False, ret_mid=False): |
| name = '%s_t%d' % (self._prefix, self._counter) |
| self._counter += 1 |
| states = self.begin_state()[0] if is_initial else states[0] |
| assert states is not None |
| if inputs is not None: |
| i2h = mx.sym.Convolution(data=inputs, |
| weight=self.i2h_weight, |
| bias=self.i2h_bias, |
| kernel=self._i2h_kernel, |
| stride=self._i2h_stride, |
| dilate=self._i2h_dilate, |
| pad=self._i2h_pad, |
| num_filter=self._num_filter, |
| name="%s_i2h" % name) |
| else: |
| i2h = None |
| h2h = mx.sym.Convolution(data=states, |
| weight=self.h2h_weight, |
| bias=self.h2h_bias, |
| kernel=self._h2h_kernel, |
| stride=(1, 1), |
| dilate=self._h2h_dilate, |
| pad=self._h2h_pad, |
| num_filter=self._num_filter, |
| name="%s_h2h" % name) |
        pre_act = i2h + h2h if i2h is not None else h2h
        if self._layer_norm:
            pre_act = layer_normalization(pre_act,
                                          num_filters=self._num_filter,
                                          name=self._prefix + "ln")
        next_h = activation(pre_act, act_type=self._act_type, name=name + "_state")
| self._curr_states = [next_h] |
| if not ret_mid: |
| return next_h, [next_h] |
| else: |
            return next_h, [next_h], [i2h, h2h]


class ConvGRU(BaseConvRNN):
| def __init__(self, num_filter, b_h_w, zoneout=0.0, |
| h2h_kernel=(3, 3), h2h_dilate=(1, 1), |
| i2h_kernel=(3, 3), i2h_stride=(1, 1), i2h_pad=(1, 1), i2h_dilate=(1, 1), |
| i2h_adj=(0, 0), no_i2h_bias=False, use_deconv=False, |
| act_type="leaky", prefix="ConvGRU", lr_mult=1.0): |
| """Initializing a ConvGRU/DeconvGRU |
| |
        r_t = \sigma(W_r \ast x_t + R_r \ast h_{t-1} + b_{W_r} + b_{R_r})
        u_t = \sigma(W_u \ast x_t + R_u \ast h_{t-1} + b_{W_u} + b_{R_u})
        h^\prime_t = f(W_h \ast x_t + r_t \circ (R_h \ast h_{t-1} + b_{R_h}) + b_{W_h})
        h_t = (1 - u_t) \circ h^\prime_t + u_t \circ h_{t-1}

        where \ast denotes convolution, \circ the element-wise product and f
        the configured activation (default: leaky ReLU).

        Parameter layout (channel-axis slices, in the order reset_gate,
        update_gate, new_mem):
            W_{i2h} = [W_r, W_u, W_h]
            b_{i2h} = [b_{W_r}, b_{W_u}, b_{W_h}]
            W_{h2h} = [R_r, R_u, R_h]
            b_{h2h} = [b_{R_r}, b_{R_u}, b_{R_h}]

        Parameters
        ----------
        num_filter : int
            Number of hidden feature maps.
        b_h_w : tuple of int
            (batch_size, height, width) of the input.
        zoneout : float
            Probability with which a unit keeps its previous state h_{t-1}.
        i2h_adj : tuple of int
            Output adjustment of the input-to-state deconvolution.
        no_i2h_bias : bool
            Drop the bias of the input-to-state deconvolution.
        use_deconv : bool
            Use a deconvolution (DeconvGRU) for the input-to-state transform.
        act_type : str
            Activation of the candidate state h^\prime_t.
        lr_mult : float
            Learning-rate multiplier attached to all parameters.
        """
| super(ConvGRU, self).__init__(num_filter=num_filter, |
| b_h_w=b_h_w, |
| h2h_kernel=h2h_kernel, |
| h2h_dilate=h2h_dilate, |
| i2h_kernel=i2h_kernel, |
| i2h_pad=i2h_pad, |
| i2h_stride=i2h_stride, |
| i2h_dilate=i2h_dilate, |
| act_type=act_type, |
| prefix=prefix) |
| self._no_i2h_bias = no_i2h_bias |
| self._i2h_adj = i2h_adj |
| self._use_deconv = use_deconv |
        if self._no_i2h_bias:
            assert use_deconv, "no_i2h_bias is only supported together with use_deconv"
| self._zoneout = zoneout |
| self.i2h_weight = self.params.get("i2h_weight", lr_mult=lr_mult) |
| self.i2h_bias = self.params.get("i2h_bias", lr_mult=lr_mult) |
| self.h2h_weight = self.params.get("h2h_weight", lr_mult=lr_mult) |
| self.h2h_bias = self.params.get("h2h_bias", lr_mult=lr_mult) |
|
|
| @property |
| def state_postfix(self): |
| return ['h'] |
|
|
| @property |
| def state_info(self): |
| return [{'shape': (self._batch_size, self._num_filter, |
| self._state_height, self._state_width), |
| '__layout__': "NCHW"}] |
|
|
| def __call__(self, inputs, states=None, is_initial=False, ret_mid=False): |
| name = '%s_t%d' % (self._prefix, self._counter) |
| self._counter += 1 |
| if is_initial: |
| states = self.begin_state()[0] |
| else: |
| states = states[0] |
| assert states is not None |
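        # The input-to-state transform produces 3 * num_filter channels, one
        # block each for the reset gate, the update gate and the candidate state.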
| if inputs is not None: |
| if self._use_deconv: |
| if self._no_i2h_bias: |
| i2h = mx.sym.Deconvolution(data=inputs, |
| weight=self.i2h_weight, |
| kernel=self._i2h_kernel, |
| stride=self._i2h_stride, |
| pad=self._i2h_pad, |
| adj=self._i2h_adj, |
| no_bias=True, |
| num_filter=self._num_filter * 3, |
| name="%s_i2h" % name) |
| else: |
| i2h = mx.sym.Deconvolution(data=inputs, |
| weight=self.i2h_weight, |
| bias=self.i2h_bias, |
| kernel=self._i2h_kernel, |
| stride=self._i2h_stride, |
| pad=self._i2h_pad, |
| adj=self._i2h_adj, |
| num_filter=self._num_filter * 3, |
| name="%s_i2h" % name) |
| else: |
| i2h = mx.sym.Convolution(data=inputs, |
| weight=self.i2h_weight, |
| bias=self.i2h_bias, |
| kernel=self._i2h_kernel, |
| stride=self._i2h_stride, |
| dilate=self._i2h_dilate, |
| pad=self._i2h_pad, |
| num_filter=self._num_filter * 3, |
| name="%s_i2h" % name) |
| i2h_slice = mx.sym.SliceChannel(i2h, num_outputs=3, axis=1) |
| else: |
| i2h_slice = None |
        prev_h = states
| h2h = mx.sym.Convolution(data=prev_h, |
| weight=self.h2h_weight, |
| bias=self.h2h_bias, |
| no_bias=False, |
| kernel=self._h2h_kernel, |
| stride=(1, 1), |
| dilate=self._h2h_dilate, |
| pad=self._h2h_pad, |
| num_filter=self._num_filter * 3, |
| name="%s_h2h" % name) |
| h2h_slice = mx.sym.SliceChannel(h2h, num_outputs=3, axis=1) |
| if i2h_slice is not None: |
| reset_gate = mx.sym.Activation(i2h_slice[0] + h2h_slice[0], act_type="sigmoid", |
| name=name + "_r") |
| update_gate = mx.sym.Activation(i2h_slice[1] + h2h_slice[1], act_type="sigmoid", |
| name=name + "_u") |
| new_mem = activation(i2h_slice[2] + reset_gate * h2h_slice[2], |
| act_type=self._act_type, |
| name=name + "_h") |
| else: |
| reset_gate = mx.sym.Activation(h2h_slice[0], act_type="sigmoid", |
| name=name + "_r") |
| update_gate = mx.sym.Activation(h2h_slice[1], act_type="sigmoid", |
| name=name + "_u") |
| new_mem = activation(reset_gate * h2h_slice[2], |
| act_type=self._act_type, |
| name=name + "_h") |
| next_h = update_gate * prev_h + (1 - update_gate) * new_mem |
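        # Zoneout: Dropout over an all-ones tensor yields a mask whose entries
        # are 0 with probability p and 1 / (1 - p) otherwise; mx.sym.where
        # treats nonzero as True, so each unit keeps its previous state with
        # probability p.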
| if self._zoneout > 0.0: |
| mask = mx.sym.Dropout(mx.sym.ones_like(prev_h), p=self._zoneout) |
| next_h = mx.sym.where(mask, next_h, prev_h) |
| self._curr_states = [next_h] |
| if not ret_mid: |
| return next_h, [next_h] |
| else: |
            return next_h, [next_h], []


if __name__ == '__main__':
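    # Smoke test / benchmark: unroll a ConvGRU for 11 steps over an input of
    # shape (seq_len, batch, channel, height, width) = (11, 3, 128, 40, 40).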
| data = mx.sym.Variable('data') |
| data = mx.sym.SliceChannel(data, axis=0, num_outputs=11, squeeze_axis=True) |
| conv_gru1 = ConvGRU(num_filter=100, b_h_w=(3, 40, 40), |
| prefix="conv_gru1") |
| out, states = conv_gru1(inputs=data[0], is_initial=True) |
| for i in range(1, 11): |
| out, states = conv_gru1(inputs=data[i], states=states) |
| conv_gru_forward_backward_time =\ |
| mx.test_utils.check_speed(out, |
| location={'data': np.random.normal(size=(11, 3, 128, 40, 40))}, |
| N=2) |
| net = mx.mod.Module(out, data_names=['data',], label_names=None, context=mx.gpu()) |
| net.bind(data_shapes=[('data', (11, 3, 128, 40, 40))], |
| grad_req='add') |
| net.init_params() |
    net.forward(mx.io.DataBatch(data=[mx.random.normal(shape=(11, 3, 128, 40, 40))], label=None),
                is_train=False)
| print(net.get_outputs()[0].asnumpy()) |
|
|
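    # Same smoke test / benchmark for the plain ConvRNN cell.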
| data = mx.sym.Variable('data') |
| data = mx.sym.SliceChannel(data, axis=0, num_outputs=11, squeeze_axis=True) |
| conv_rnn1 = ConvRNN(num_filter=100, b_h_w=(3, 40, 40), |
| prefix="conv_rnn1") |
| out, states = conv_rnn1(inputs=data[0], is_initial=True) |
| for i in range(1, 11): |
| out, states = conv_rnn1(inputs=data[i], states=states) |
| conv_rnn_forward_backward_time = \ |
| mx.test_utils.check_speed(out, |
| location={'data': np.random.normal(size=(11, 3, 128, 40, 40))}, |
| N=2) |
| net = mx.mod.Module(out, data_names=['data', ], label_names=None, context=mx.gpu()) |
| net.bind(data_shapes=[('data', (11, 3, 128, 40, 40))], |
| grad_req='add') |
| net.init_params() |
| net.forward(mx.io.DataBatch(data=[mx.random.normal(shape=(11, 3, 128, 40, 40))], label=None), |
| is_train=False) |
| print(net.get_outputs()[0].asnumpy()) |
|
|
| print("ConvGRU Time:", conv_gru_forward_backward_time) |
| print("ConvRNN Time:", conv_rnn_forward_backward_time) |
|
|