import mxnet as mx
from mxnet.rnn import BaseRNNCell
from nowcasting.ops import activation
from nowcasting.operators.common import group_add


class MyBaseRNNCell(BaseRNNCell):
    def __init__(self, prefix="MyBaseRNNCell", params=None):
        super(MyBaseRNNCell, self).__init__(prefix=prefix, params=params)


    def __call__(self, inputs, states, is_initial=False, ret_mid=False):
        raise NotImplementedError()


    def reset(self):
        super(MyBaseRNNCell, self).reset()
        # Drop the states cached from the previous unroll.
        self._curr_states = None


    def get_current_states(self):
        """Return the states produced by the most recent step (None after reset())."""
        return self._curr_states


    def unroll(self, length, inputs=None, begin_state=None, ret_mid=False,
               input_prefix='', layout='TC', merge_outputs=False):
        """Unroll the RNN cell across time steps.

        Parameters
        ----------
        length : int
            Number of steps to unroll.
        inputs : Symbol, list of Symbol, or None
            If inputs is a single Symbol (usually the output of an Embedding
            symbol), it should have shape (length, batch_size, ...) if
            layout == 'TNC', or (length, ...) if layout == 'TC'.

            If inputs is a list of Symbols (usually the output of a previous
            unroll), it must contain `length` elements, each of shape
            (batch_size, ...).

            If inputs is None, placeholder variables are automatically
            created.
        begin_state : nested list of Symbol, optional
            Input states, created by begin_state() or as the output states of
            another cell. Created from begin_state() if None.
        ret_mid : bool
            Whether to also return the intermediate information collected at
            every step.
        input_prefix : str
            Prefix for the automatically created input placeholders.
        layout : str
            Layout of the input symbol, either 'TNC' or 'TC'. Only used if
            inputs is a single Symbol.
        merge_outputs : bool
            If False, return outputs as a list of Symbols.
            If True, expand each output along a new leading time axis and
            concatenate the results into a single symbol of shape
            (length, batch_size, ...).

        Returns
        -------
        outputs : list of Symbol, or Symbol
            Output symbols; a single concatenated Symbol if merge_outputs is
            True.
        states : Symbol or nested list of Symbol
            Has the same structure as begin_state().
        mid_info : list of Symbol
            Only returned when ret_mid is True.
        """
        self.reset()
        assert layout == 'TNC' or layout == 'TC'
        if inputs is not None:
            if isinstance(inputs, mx.sym.Symbol):
                assert len(inputs.list_outputs()) == 1, \
                    "unroll doesn't allow grouped symbol as input. Please " \
                    "convert to list first or let unroll handle slicing"
                if 'N' in layout:
                    # Slice along the leading time axis and squeeze it away,
                    # giving `length` symbols of shape (batch_size, ...).
                    inputs = mx.sym.SliceChannel(inputs, axis=0, num_outputs=length,
                                                 squeeze_axis=1)
                else:
                    # Without squeezing, each slice keeps a leading singleton axis.
                    inputs = mx.sym.SliceChannel(inputs, axis=0, num_outputs=length)
            else:
                assert len(inputs) == length
        else:
            inputs = [None] * length
        if begin_state is None:
            states = self.begin_state()
        else:
            states = begin_state
        outputs = []
        mid_infos = []
        for i in range(length):
            output, states, mid_info = self(inputs=inputs[i], states=states,
                                            is_initial=(i == 0 and begin_state is None),
                                            ret_mid=True)
            outputs.append(output)
            mid_infos.extend(mid_info)
        if merge_outputs:
            outputs = [mx.sym.expand_dims(i, axis=0) for i in outputs]
            outputs = mx.sym.Concat(*outputs, dim=0)
        if ret_mid:
            return outputs, states, mid_infos
        else:
            return outputs, states
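
    # Usage sketch (hypothetical subclass and shapes): with a concrete cell
    # such as MyGRU defined below,
    #
    #     cell = MyGRU(num_hidden=64)
    #     seq = mx.sym.var('seq', shape=(10, 4, 128))  # layout 'TNC'
    #     outs, states = cell.unroll(length=10, inputs=seq, layout='TNC',
    #                                merge_outputs=True)
    #
    # `outs` is then a single (10, 4, 64) symbol and `states` holds the final
    # step's states.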



class BaseStackRNN(object):
    def __init__(self, base_rnn_class, stack_num=1,
                 name="BaseStackRNN", residual_connection=True,
                 **kwargs):
        self._base_rnn_class = base_rnn_class
        # If True, every layer above the first adds its inputs to its outputs
        # (residual connections between the stacked cells).
        self._residual_connection = residual_connection
        self._name = name
        self._stack_num = stack_num
        self._prefix = name + "_"
        self._rnns = [base_rnn_class(prefix="%s_%d" % (self._name, i), **kwargs)
                      for i in range(stack_num)]
        self._init_counter = 0
        self._state_info = None


    def init_state_vars(self):
        """Create the initial state variables for this stack of cells.

        Returns
        -------
        state_vars : nested list of Symbol
            Starting states for the first RNN step, one variable per
            (concatenated) state slot.
        """
        state_vars = []
        for i, info in enumerate(self.state_info):
            state = mx.sym.var(name='%s_begin_state_%s' % (self._name, self.state_postfix[i]),
                               **info)
            state_vars.append(state)
        return state_vars


    def concat_to_split(self, concat_states):
        """Split per-slot concatenated states back into per-layer state lists."""
        assert len(concat_states) == len(self.state_info)
        split_states = [[] for i in range(self._stack_num)]
        for i in range(len(self.state_info)):
            # The channel axis is read off the state's layout string, e.g. 'NC' -> 1.
            channel_axis = self.state_info[i]['__layout__'].lower().find('c')
            ele = mx.sym.split(concat_states[i], num_outputs=self._stack_num,
                               axis=channel_axis)
            for j in range(self._stack_num):
                split_states[j].append(ele[j])
        return split_states


    def split_to_concat(self, split_states):
        """Concatenate per-layer state lists into one symbol per state slot."""
        concat_states = []
        for i in range(len(self.state_info)):
            channel_axis = self.state_info[i]['__layout__'].lower().find('c')
            concat_states.append(mx.sym.concat(*[ele[i] for ele in split_states],
                                               dim=channel_axis))
        return concat_states
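
    # Example (hypothetical shapes): with stack_num=3 and a per-cell state of
    # layout 'NC' and shape (batch, 32), the concatenated form holds one
    # (batch, 96) symbol per state slot; concat_to_split() slices it back into
    # three (batch, 32) pieces along the channel axis.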


    def check_concat(self, states):
        """Return True if `states` is in the concatenated (flat) representation."""
        ret = not isinstance(states[0], list)
        return ret


    def to_concat(self, states):
        if not self.check_concat(states):
            states = self.split_to_concat(states)
        return states


    def to_split(self, states):
        if self.check_concat(states):
            states = self.concat_to_split(states)
        return states


    @property
    def state_postfix(self):
        return self._rnns[0].state_postfix


    @property
    def state_info(self):
        """Combined state info: per-cell shapes summed along the channel axis."""
        if self._state_info is None:
            info = []
            for i in range(len(self._rnns[0].state_info)):
                ele = {}
                for rnn in self._rnns:
                    if 'shape' not in ele:
                        ele['shape'] = list(rnn.state_info[i]['shape'])
                    else:
                        channel_dim = rnn.state_info[i]['__layout__'].lower().find('c')
                        ele['shape'][channel_dim] += rnn.state_info[i]['shape'][channel_dim]
                    if '__layout__' not in ele:
                        ele['__layout__'] = rnn.state_info[i]['__layout__'].upper()
                    else:
                        assert rnn.state_info[i]['__layout__'] == ele['__layout__'].upper()
                ele['shape'] = tuple(ele['shape'])
                info.append(ele)
            self._state_info = info
            return info
        else:
            return self._state_info
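
    # For example, stacking five cells whose state_info is
    # [{'shape': (0, 32), '__layout__': 'NC'}] gives
    # [{'shape': (0, 160), '__layout__': 'NC'}]: channel sizes are summed and
    # the layouts must agree across the stacked cells.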


    def flatten_add_layout(self, states, blocked=False):
        """Flatten states to the concatenated form and attach layout attributes.

        Parameters
        ----------
        states : list of Symbol, or list of list of Symbol
            States in either the concatenated or the per-layer representation.
        blocked : bool
            If True, wrap each state in BlockGrad so no gradient flows through.

        Returns
        -------
        ret : list of Symbol
        """
        states = self.to_concat(states)
        assert self.check_concat(states)
        ret = []
        for i, ele in enumerate(states):
            if blocked:
                ret.append(mx.sym.BlockGrad(ele, __layout__=self.state_info[i]['__layout__']))
            else:
                ele._set_attr(__layout__=self.state_info[i]['__layout__'])
                ret.append(ele)
        return ret
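
    # The `__layout__` attribute is what concat_to_split()/split_to_concat()
    # use to locate the channel axis; `blocked=True` additionally stops
    # gradients at the states, e.g. when they are carried across unrolled
    # segments.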


    def reset(self):
        for i in range(len(self._rnns)):
            self._rnns[i].reset()


    def unroll(self, length, inputs=None, begin_states=None, ret_mid=False):
        if begin_states is None:
            begin_states = self.init_state_vars()
        begin_states = self.to_split(begin_states)
        assert len(begin_states) == self._stack_num, len(begin_states)
        for ele in begin_states:
            assert len(ele) == len(self.state_info)
        outputs = []
        final_states = []
        mid_infos = []
        for i in range(len(self._rnns)):
            rnn_out_list, rnn_final_states, rnn_mid_infos = \
                self._rnns[i].unroll(length=length, inputs=inputs,
                                     begin_state=begin_states[i],
                                     layout="TC",
                                     ret_mid=True)
            if self._residual_connection and i > 0:
                # Residual connection: add the layer's inputs to its outputs.
                rnn_out_list = group_add(lhs=rnn_out_list, rhs=inputs)
            # The outputs of this layer feed the next layer in the stack.
            inputs = rnn_out_list
            outputs.append(rnn_out_list)
            final_states.append(rnn_final_states)
            mid_infos.append(rnn_mid_infos)
        if ret_mid:
            return outputs, final_states, mid_infos
        else:
            return outputs, final_states
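
    # Data flow with residual_connection=True (hypothetical three-layer stack):
    #
    #     h1 = RNN_0(x)
    #     h2 = RNN_1(h1) + h1
    #     h3 = RNN_2(h2) + h2
    #
    # `outputs` keeps every layer's output sequence, not just the last one.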




class MyGRU(MyBaseRNNCell):
    """GRU cell.

    Parameters
    ----------
    num_hidden : int
        Number of units in the output symbol.
    zoneout : float
        Probability of keeping the previous hidden state at each step
        (zoneout regularization). Disabled if 0.0.
    act_type : str
        Activation used for the candidate memory, 'tanh' by default.
    prefix : str, default 'gru_'
        Prefix for the names of the layers
        (and the names of the weights if params is None).
    params : RNNParams or None
        Container for weight sharing between cells.
        Created if None.
    """
    def __init__(self, num_hidden, zoneout=0.0, act_type="tanh", prefix='gru_', params=None):
        super(MyGRU, self).__init__(prefix=prefix, params=params)
        self._num_hidden = num_hidden
        self._act_type = act_type
        self._zoneout = zoneout
        self._i2h_weight = self.params.get('i2h_weight')
        self._i2h_bias = self.params.get('i2h_bias')
        self._h2h_weight = self.params.get('h2h_weight')
        self._h2h_bias = self.params.get('h2h_bias')


    @property
    def state_info(self):
        """Shape(s) of the states; 0 marks the batch axis, inferred at bind time."""
        return [{'shape': (0, self._num_hidden), '__layout__': "NC"}]


    def __call__(self, inputs, states=None, is_initial=False, ret_mid=False):
        name = '%s_t%d' % (self._prefix, self._counter)
        self._counter += 1
        if is_initial:
            prev_h = self.begin_state()[0]
        else:
            assert states is not None
            prev_h = states[0]
        if inputs is not None:
            inputs = mx.sym.reshape(inputs, shape=(0, -1))
            # Project the inputs to all three gates at once.
            i2h = mx.sym.FullyConnected(data=inputs,
                                        num_hidden=self._num_hidden * 3,
                                        weight=self._i2h_weight,
                                        bias=self._i2h_bias,
                                        name="%s_i2h" % name)
            i2h_slice = mx.sym.SliceChannel(i2h, num_outputs=3, axis=1)
        else:
            i2h_slice = None
        h2h = mx.sym.FullyConnected(data=prev_h,
                                    num_hidden=self._num_hidden * 3,
                                    weight=self._h2h_weight,
                                    bias=self._h2h_bias,
                                    name="%s_h2h" % name)
        h2h_slice = mx.sym.SliceChannel(h2h, num_outputs=3, axis=1)
        if i2h_slice is not None:
            reset_gate = activation(i2h_slice[0] + h2h_slice[0], act_type="sigmoid",
                                    name=name + "_r")
            update_gate = activation(i2h_slice[1] + h2h_slice[1], act_type="sigmoid",
                                     name=name + "_u")
            new_mem = activation(i2h_slice[2] + reset_gate * h2h_slice[2],
                                 act_type=self._act_type,
                                 name=name + "_h")
        else:
            # No inputs: the gates are driven by the previous state alone.
            reset_gate = activation(h2h_slice[0], act_type="sigmoid",
                                    name=name + "_r")
            update_gate = activation(h2h_slice[1], act_type="sigmoid",
                                     name=name + "_u")
            new_mem = activation(reset_gate * h2h_slice[2],
                                 act_type=self._act_type,
                                 name=name + "_h")
        next_h = update_gate * prev_h + (1 - update_gate) * new_mem
        if self._zoneout > 0.0:
            # Zoneout: randomly keep the previous state element-wise. Dropout
            # zeroes entries of the all-ones mask with probability `zoneout`,
            # and `where` falls back to prev_h wherever the mask is zero.
            mask = mx.sym.Dropout(mx.sym.ones_like(prev_h), p=self._zoneout)
            next_h = mx.sym.where(mask, next_h, prev_h)
        self._curr_states = [next_h]
        if not ret_mid:
            return next_h, [next_h]
        else:
            return next_h, [next_h], []
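
    # Update rule implemented above (the reset gate is applied after the
    # hidden-to-hidden projection, sometimes called "linear before reset"):
    #
    #     r_t = sigmoid(i2h_0 + h2h_0)
    #     u_t = sigmoid(i2h_1 + h2h_1)
    #     m_t = act(i2h_2 + r_t * h2h_2)
    #     h_t = u_t * h_{t-1} + (1 - u_t) * m_t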


if __name__ == '__main__':
    from nowcasting.operators.conv_rnn import ConvGRU
    brnn1 = BaseStackRNN(base_rnn_class=ConvGRU, stack_num=5,
                         b_h_w=(4, 32, 32), num_filter=32)
    print(brnn1.state_info)
    inputs = mx.sym.var(name="inputs", shape=(8, 4, 16, 32, 32))
    outputs, final_states, mid_infos = brnn1.unroll(length=8, inputs=inputs, ret_mid=True)
    print(len(outputs), len(outputs[0]))
    print(len(final_states), len(final_states[0]))
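
    # A minimal MyGRU check (hypothetical shapes, symbol construction only):
    # unroll a single GRU over 10 steps with an explicit begin state, so no
    # deferred-shape zeros are needed.
    gru = MyGRU(num_hidden=16)
    gru_inputs = mx.sym.var(name="gru_inputs", shape=(10, 4, 32))
    gru_begin = [mx.sym.var(name="gru_begin_state", shape=(4, 16))]
    gru_outputs, gru_states = gru.unroll(length=10, inputs=gru_inputs,
                                         begin_state=gru_begin, layout='TNC')
    print(len(gru_outputs), len(gru_states))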