| import mxnet as mx |
| from nowcasting.config import cfg |
| from nowcasting.ops import reset_regs |
| from nowcasting.operators.common import grid_generator |
|
|
|
|
class PredictionBaseFactory(object):
    """Common scaffolding for forecaster factories that build MXNet symbols.

    Subclasses must implement ``_init_rnn``. This base class records the
    problem geometry, precomputes a spatial grid via ``grid_generator`` and
    resets the RNN cells returned by ``_init_rnn``.
    """

    def __init__(self, batch_size, in_seq_len, out_seq_len, height, width, name="forecaster"):
        """Record configuration, build the RNN cells and reset them.

        Args:
            batch_size: batch size; also passed to ``grid_generator`` and used
                in the broadcast shapes of ``_pre_encode_frame``.
            in_seq_len: stored as ``self._in_seq_len``.
            out_seq_len: stored as ``self._out_seq_len``.
            height: frame height, stored as ``self._height``.
            width: frame width, stored as ``self._width``.
            name: stored as ``self._name``.
        """
        # The output type is taken from the global config, not from an argument.
        self._out_typ = cfg.MODEL.OUT_TYPE
        self._batch_size, self._in_seq_len, self._out_seq_len = batch_size, in_seq_len, out_seq_len
        self._height, self._width = height, width
        self._name = name
        # Grid produced by the project helper; _pre_encode_frame appends it to the frames.
        self._spatial_grid = grid_generator(batch_size=batch_size, height=height, width=width)
        # The subclass hook runs only after every attribute above exists, since it may read them.
        self.rnn_list = self._init_rnn()
        self._reset_rnn()

    def _pre_encode_frame(self, frame_data, seqlen):
        """Append the spatial grid and a constant-ones channel to ``frame_data``.

        Args:
            frame_data: symbol concatenated along axis 2. For the Concat to be
                valid, its other axes must match
                ``(seqlen, batch_size, ., height, width)``.
            seqlen: number of leading (time) steps the grid and ones are
                broadcast over.

        Returns:
            ``mx.sym.Concat`` of ``frame_data``, the 2-channel broadcast grid and
            a 1-channel ones tensor along axis 2.
        """
        batch, rows, cols = self._batch_size, self._height, self._width
        # Add a leading axis to the grid, then tile it over the sequence.
        grid = mx.sym.expand_dims(self._spatial_grid, axis=0)
        grid = mx.sym.broadcast_to(grid, shape=(seqlen, batch, 2, rows, cols))
        # Constant channel of ones.
        bias = mx.sym.ones(shape=(seqlen, batch, 1, rows, cols))
        return mx.sym.Concat(frame_data, grid, bias, num_args=3, dim=2)

    def _init_rnn(self):
        """Build and return the RNN cells. Subclasses must override this.

        ``_reset_rnn`` iterates the returned value and calls ``reset()`` on
        each element, so every element must provide that method.
        """
        raise NotImplementedError

    def _reset_rnn(self):
        """Call ``reset()`` on every cell in ``self.rnn_list``, in order."""
        for cell in self.rnn_list:
            cell.reset()

    def reset_all(self):
        """Call the project helper ``reset_regs()``, then reset every RNN cell.

        ``reset_regs`` comes from ``nowcasting.ops``; its exact effect is
        defined there.
        """
        reset_regs()
        self._reset_rnn()
|
|
|
|
class RecursiveOneStepBaseFactory(PredictionBaseFactory):
    """Forecaster factory base that additionally records a ``use_ss`` flag.

    Judging by the name, subclasses presumably predict recursively one step
    at a time; this class itself only adds the ``_use_ss`` attribute on top of
    ``PredictionBaseFactory``.
    """

    def __init__(self, batch_size, in_seq_len, out_seq_len, height, width, use_ss=False,
                 name="forecaster"):
        """Initialise the base factory and store the ``use_ss`` flag.

        Args:
            batch_size: forwarded to ``PredictionBaseFactory``.
            in_seq_len: forwarded to ``PredictionBaseFactory``.
            out_seq_len: forwarded to ``PredictionBaseFactory``.
            height: forwarded to ``PredictionBaseFactory``.
            width: forwarded to ``PredictionBaseFactory``.
            use_ss: stored as ``self._use_ss``. This is presumably
                "use scheduled sampling"; confirm against the subclasses
                that read it.
            name: forwarded to ``PredictionBaseFactory``.
        """
        super(RecursiveOneStepBaseFactory, self).__init__(batch_size=batch_size,
                                                          in_seq_len=in_seq_len,
                                                          out_seq_len=out_seq_len,
                                                          height=height,
                                                          width=width,
                                                          name=name)
        # Honour the caller's argument. The previous code hard-coded False here,
        # silently ignoring use_ss=True.
        self._use_ss = use_ss
|
|
|
|