| import mxnet as mx |
| from nowcasting.ops import \ |
| conv3d, conv3d_act, conv3d_bn_act, \ |
| conv2d, conv2d_act, conv2d_bn_act, \ |
| deconv3d, deconv3d_act, deconv3d_bn_act, \ |
| deconv2d, deconv2d_act, deconv2d_bn_act, \ |
| fc_layer, fc_layer_act |
| from nowcasting.config import cfg |
|
|
|
|
| |
def encode_net_symbol(data,
                      data_type,
                      no_bias=False,
                      momentum=0.9,
                      fix_gamma=False,
                      eps=1e-5 + 1e-12,
                      postfix=""):
    """Construct encode_net symbol.

    A stack of strided convolutions (encode_net_1..4, plus encode_net_0 for
    HKO or 20-frame MovingMNIST) that downsamples the input sequence.
    Kernel, stride and padding are kept as [time, height, width] lists that
    are mutated in place between layers, so each layer inherits whatever
    entries the preceding code did not overwrite.

    Args:
        data: input data (context or pred). When
            cfg.MODEL.DECONVBASELINE.USE_3D is False it is reshaped to
            (BATCH_SIZE, length, IMG_SIZE, IMG_SIZE), i.e. frames are stacked
            as channels of a 2-D convolution. In 3-D mode it is used as-is
            (presumably (batch, channel, time, height, width) -- confirm).
        data_type: If "context" use IN_LEN, if "pred" use OUT_LEN, if
            "contextpred" use IN_LEN + OUT_LEN.
        no_bias: Passed to every convolution.
        momentum, fix_gamma, eps: BatchNorm settings for encode_net_2..4.
        postfix: Postfix for symbol names. Parameters will be shared between
            symbols created during calls to encode_net_symbol with the same
            data_type and postfix arguments.

    Returns:
        The output symbol of layer 'encode_net_4' + "_<data_type>_<postfix>".

    Raises:
        NotImplementedError: if cfg.DATASET or data_type is unsupported.
        AssertionError: if the sequence length is not one of the lengths the
            layer geometry below was written for.
    """
    if cfg.DATASET == "MOVINGMNIST":
        IN_LEN = cfg.MOVINGMNIST.IN_LEN
        OUT_LEN = cfg.MOVINGMNIST.OUT_LEN
        IMG_SIZE = cfg.MOVINGMNIST.IMG_SIZE
    elif cfg.DATASET == "HKO":
        IN_LEN = cfg.HKO.BENCHMARK.IN_LEN
        OUT_LEN = cfg.HKO.BENCHMARK.OUT_LEN
        IMG_SIZE = cfg.HKO.ITERATOR.WIDTH
    else:
        # Fail fast instead of hitting an UnboundLocalError further down.
        raise NotImplementedError(
            "Unsupported cfg.DATASET: %r" % (cfg.DATASET,))

    if data_type == "context":
        length = IN_LEN
    elif data_type == "pred":
        length = OUT_LEN
    elif data_type == "contextpred":
        length = IN_LEN + OUT_LEN
    else:
        raise NotImplementedError("Unsupported data_type: %r" % (data_type,))

    # Layer names embed data_type and postfix; identical names are what make
    # separate calls share parameters.
    postfix = "_" + data_type + "_" + postfix

    use_3d = cfg.MODEL.DECONVBASELINE.USE_3D
    act_type = cfg.MODEL.CNN_ACT_TYPE
    base_filter = cfg.MODEL.DECONVBASELINE.BASE_NUM_FILTER

    if not use_3d:
        # 2-D mode: frames are stacked along the channel axis.
        data = mx.sym.reshape(
            data,
            shape=(cfg.MODEL.TRAIN.BATCH_SIZE, length, IMG_SIZE, IMG_SIZE))

    # Sequence lengths the temporal strides/pads below were written for.
    if cfg.DATASET == "MOVINGMNIST":
        assert length in [1, 10, 11, 20], "unsupported length %d" % length
    elif cfg.DATASET == "HKO":
        assert length in [1, 5, 20, 21, 25], "unsupported length %d" % length

    # Arguments shared by the BatchNorm layers encode_net_2..4.
    bn_args = dict(
        use_global_stats=cfg.MODEL.DECONVBASELINE.BN_GLOBAL_STATS,
        use_3d=use_3d,
        use_bn=cfg.MODEL.DECONVBASELINE.BN,
        act_type=act_type,
        no_bias=no_bias,
        fix_gamma=fix_gamma,
        eps=eps,
        momentum=momentum)

    # [time, height, width]; mutated in place between layers.
    k = [1, 1, 1]
    s = [1, 1, 1]
    p = [0, 0, 0]

    # Optional extra input layer encode_net_0.
    if cfg.DATASET == "HKO" or (cfg.DATASET == "MOVINGMNIST" and length == 20):
        if length > 11:
            # Halve long sequences along time.
            k[0] = 4
            s[0] = 2
            p[0] = 1
        if cfg.DATASET == "HKO":
            # Aggressive spatial downsampling for the large HKO frames.
            k[1:] = [7, 7]
            s[1:] = [5, 5]
            p[1:] = [1, 1]
        data = conv2d_3d_act(
            use_3d=use_3d,
            data=data,
            name='encode_net_0' + postfix,
            act_type=act_type,
            kernel=k,
            stride=s,
            pad=p,
            num_filter=base_filter,
            no_bias=no_bias)

    # encode_net_1: 4x4 spatial kernel, stride 2 (stride 3 for HKO).
    k[1:] = [4, 4]
    s[1:] = [2, 2]
    p[1:] = [1, 1]
    if length >= 10:
        k[0] = 4
        s[0] = 2
        p[0] = 1
    if cfg.DATASET == "HKO":
        s[1:] = [3, 3]
    e1 = conv2d_3d_act(
        use_3d=use_3d,
        data=data,
        name='encode_net_1' + postfix,
        act_type=act_type,
        kernel=k,
        stride=s,
        pad=p,
        num_filter=base_filter,
        no_bias=no_bias)

    # encode_net_2 / encode_net_3: same geometry, stride 2 in space.
    if length >= 5:
        k[0] = 4
        s[0] = 2
        p[0] = 1
    if cfg.DATASET == "HKO":
        s[1:] = [2, 2]
    # height/width are forwarded to the BN helpers; presumably they must
    # equal each layer's output spatial size.
    e2 = conv2d_3d_bn_act(
        data=e1,
        name='encode_net_2' + postfix,
        kernel=k,
        stride=s,
        pad=p,
        num_filter=base_filter * 2,
        height=16,
        width=16,
        **bn_args)
    e3 = conv2d_3d_bn_act(
        data=e2,
        name='encode_net_3' + postfix,
        kernel=k,
        stride=s,
        pad=p,
        num_filter=base_filter * 3,
        height=8,
        width=8,
        **bn_args)

    # encode_net_4: wider temporal padding for the last layer.
    p[0] = 2
    e4 = conv2d_3d_bn_act(
        data=e3,
        name='encode_net_4' + postfix,
        kernel=k,
        stride=s,
        pad=p,
        num_filter=base_filter * 4,
        height=4,
        width=4,
        **bn_args)

    return e4
|
|
|
|
def video_net_symbol(encode_net,
                     no_bias=False,
                     momentum=0.9,
                     fix_gamma=False,
                     eps=1e-5 + 1e-12):
    """Construct the decoder that turns an encoding into OUT_LEN frames.

    Four deconvolution layers (video_net_d1..d4) upsample the encoding.
    They are followed by a plain Deconvolution 'gen_out'. For HKO, two more
    plain Deconvolution layers ('gen_out_scale', 'gen_out_scale2') are
    appended. Kernel, stride and padding are [time, height, width] lists
    mutated in place between layers.

    Args:
        encode_net: Encoding symbol. Presumably shaped (batch, C, 1, 4, 4) in
            3-D mode or (batch, C, 4, 4) in 2-D mode, matching the reshape
            done in generator_symbol -- confirm.
        no_bias: Passed to every (de)convolution.
        momentum, fix_gamma, eps: BatchNorm settings for d2..d4.

    Returns:
        The last Deconvolution symbol (no BN, no activation).

    Raises:
        NotImplementedError: if cfg.DATASET is unsupported.
        AssertionError: if OUT_LEN is not 1, 10 or 20.
    """
    if cfg.DATASET == "MOVINGMNIST":
        OUT_LEN = cfg.MOVINGMNIST.OUT_LEN
    elif cfg.DATASET == "HKO":
        OUT_LEN = cfg.HKO.BENCHMARK.OUT_LEN
    else:
        # Fail fast instead of hitting an UnboundLocalError below.
        raise NotImplementedError(
            "Unsupported cfg.DATASET: %r" % (cfg.DATASET,))

    # Temporal kernels/pads below were written for these lengths only.
    assert OUT_LEN in [1, 10, 20], "unsupported OUT_LEN %d" % OUT_LEN

    use_3d = cfg.MODEL.DECONVBASELINE.USE_3D
    act_type = cfg.MODEL.CNN_ACT_TYPE
    base_filter = cfg.MODEL.DECONVBASELINE.BASE_NUM_FILTER
    # Arguments shared by the BatchNorm layers d2..d4.
    bn_args = dict(
        use_global_stats=cfg.MODEL.DECONVBASELINE.BN_GLOBAL_STATS,
        use_3d=use_3d,
        use_bn=cfg.MODEL.DECONVBASELINE.BN,
        act_type=act_type,
        no_bias=no_bias,
        fix_gamma=fix_gamma,
        eps=eps,
        momentum=momentum)

    # [time, height, width]; mutated in place between layers.
    k = [1, 1, 1]
    s = [1, 1, 1]
    p = [0, 0, 0]

    def _plain_deconv(inp, name, num_filter_3d, num_filter_2d):
        # Deconvolution without BN/activation using the *current* k/s/p:
        # all three axes in 3-D mode, only the spatial part in 2-D mode.
        if use_3d:
            return mx.sym.Deconvolution(
                data=inp,
                name=name,
                kernel=k,
                stride=s,
                pad=p,
                num_filter=num_filter_3d,
                no_bias=no_bias)
        return mx.sym.Deconvolution(
            data=inp,
            name=name,
            kernel=k[1:],
            stride=s[1:],
            pad=p[1:],
            num_filter=num_filter_2d,
            no_bias=no_bias)

    # d1: 1x1 spatial kernel; temporal kernel 2 when predicting >1 frame.
    if OUT_LEN > 1:
        k[0] = 2
    d1 = deconv2d_3d_act(
        use_3d=use_3d,
        data=encode_net,
        name='video_net_d1',
        kernel=k,
        stride=s,
        pad=p,
        num_filter=base_filter * 8,
        act_type=act_type,
        no_bias=no_bias)

    # d2..d4: 4x4 kernel, stride 2, pad 1 doubles the spatial size per layer.
    k[1:] = [4, 4]
    s[1:] = [2, 2]
    p[1:] = [1, 1]
    if OUT_LEN >= 10:
        k[0] = 4
        s[0] = 2
        p[0] = 1
    d2 = deconv2d_3d_bn_act(
        data=d1,
        name='video_net_d2',
        kernel=k,
        stride=s,
        pad=p,
        num_filter=base_filter * 4,
        height=8,
        width=8,
        **bn_args)

    # Temporal padding presumably tuned so the time axis ends at OUT_LEN.
    if OUT_LEN == 10:
        p[0] = 2
    elif OUT_LEN == 20:
        p[0] = 0
    d3 = deconv2d_3d_bn_act(
        data=d2,
        name='video_net_d3',
        kernel=k,
        stride=s,
        pad=p,
        num_filter=base_filter * 2,
        height=16,
        width=16,
        **bn_args)

    if OUT_LEN == 20:
        p[0] = 1
    d4 = deconv2d_3d_bn_act(
        data=d3,
        name='video_net_d4',
        kernel=k,
        stride=s,
        pad=p,
        num_filter=base_filter,
        height=32,
        width=32,
        **bn_args)

    # Output head.
    out_filter = 1
    if OUT_LEN > 1:
        k[0] = 3
        s[0] = 1
        p[0] = 1
    if cfg.DATASET == "HKO":
        k[1:] = [5, 5]
        s[1:] = [3, 3]
        p[1:] = [1, 1]
        out_filter = 8
    gen_out = _plain_deconv(d4, 'gen_out', out_filter, OUT_LEN * out_filter)

    if cfg.DATASET == "HKO":
        # Extra upsampling (7x7, stride 5) for the larger HKO frames ...
        k[1:] = [7, 7]
        s[1:] = [5, 5]
        p[1:] = [1, 1]
        gen_out = _plain_deconv(
            gen_out, 'gen_out_scale', 1 * out_filter, OUT_LEN * out_filter)
        # ... then a 3x3 stride-1 layer down to one channel (3-D) or one
        # channel per frame (2-D).
        k[1:] = [3, 3]
        s[1:] = [1, 1]
        p[1:] = [1, 1]
        gen_out = _plain_deconv(gen_out, 'gen_out_scale2', 1, OUT_LEN)

    return gen_out
|
|
|
|
def generator_symbol(context,
                     no_bias=False,
                     momentum=0.9,
                     fix_gamma=False,
                     eps=1e-5 + 1e-12):
    """Construct the generator: encode the context, then decode a forecast.

    Args:
        context: Context sequence symbol, encoded with
            encode_net_symbol(data_type="context") and the default postfix.
        no_bias, momentum, fix_gamma, eps: Forwarded to the encoder and
            decoder builders.

    Returns:
        mx.sym.Group with two outputs:
          * a reshape named "pred" of the decoder output, shaped
            (BATCH_SIZE, 1, OUT_LEN, IMG_SIZE, IMG_SIZE);
          * "forecast_target": the same tensor clipped to [0, 1] and wrapped
            in BlockGrad so no gradient flows back through it.

    Raises:
        NotImplementedError: if cfg.DATASET is unsupported.
    """
    encode_net = encode_net_symbol(
        data=context,
        data_type="context",
        no_bias=no_bias,
        momentum=momentum,
        fix_gamma=fix_gamma,
        eps=eps)

    # Optional fully connected bottleneck between encoder and decoder.
    if cfg.MODEL.DECONVBASELINE.FC_BETWEEN_ENCDEC:
        encode_net = mx.sym.FullyConnected(
            data=encode_net,
            num_hidden=cfg.MODEL.DECONVBASELINE.FC_BETWEEN_ENCDEC)

    # The decoder expects a 4x4 spatial map (with a singleton time axis in
    # 3-D mode); -1 folds everything else into channels.
    if cfg.MODEL.DECONVBASELINE.USE_3D:
        encode_net = mx.sym.Reshape(
            data=encode_net, shape=(cfg.MODEL.TRAIN.BATCH_SIZE, -1, 1, 4, 4))
    else:
        encode_net = mx.sym.Reshape(
            data=encode_net, shape=(cfg.MODEL.TRAIN.BATCH_SIZE, -1, 4, 4))

    gen_net = video_net_symbol(
        encode_net,
        no_bias=no_bias,
        momentum=momentum,
        fix_gamma=fix_gamma,
        eps=eps)

    if cfg.DATASET == "MOVINGMNIST":
        OUT_LEN = cfg.MOVINGMNIST.OUT_LEN
        IMG_SIZE = cfg.MOVINGMNIST.IMG_SIZE
    elif cfg.DATASET == "HKO":
        OUT_LEN = cfg.HKO.BENCHMARK.OUT_LEN
        IMG_SIZE = cfg.HKO.ITERATOR.WIDTH
    else:
        # Fail with a clear error rather than an UnboundLocalError.
        raise NotImplementedError(
            "Unsupported cfg.DATASET: %r" % (cfg.DATASET,))

    # Same output layout regardless of 2-D / 3-D decoding.
    gen_net = mx.sym.reshape(
        gen_net,
        shape=(cfg.MODEL.TRAIN.BATCH_SIZE, 1, OUT_LEN, IMG_SIZE, IMG_SIZE),
        name="pred")

    return mx.sym.Group([
        gen_net,
        mx.sym.BlockGrad(
            mx.sym.clip(gen_net, a_min=0, a_max=1), name="forecast_target")
    ])
|
|
|
|
def discriminator_symbol(context,
                         pred,
                         no_bias=False,
                         momentum=0.9,
                         fix_gamma=False,
                         eps=1e-5 + 1e-12):
    """Construct the discriminator that scores a (context, pred) pair.

    The encoder layout is selected by cfg.MODEL.DECONVBASELINE.ENCODER:
      * "shared" / "separate": context and pred are encoded separately and
        the encodings concatenated. With "shared" the context encoder uses
        the default (empty) postfix, i.e. the same parameter names as the
        context encoder built by generator_symbol; "separate" uses "_gan".
      * "concat": context and pred are concatenated along axis 2 first and
        encoded together with data_type="contextpred".

    Args:
        context: Context sequence symbol.
        pred: Predicted (or ground-truth) sequence symbol.
        no_bias, momentum, fix_gamma, eps: Forwarded to conv / BN layers.

    Returns:
        A score symbol: Flatten of the 1-filter conv 'discriminator_6' when
        COMPAT.CONV_INSTEADOF_FC_IN_ENCODER is set, otherwise the 1-unit
        fully connected layer 'discriminator_fc_2'.

    Raises:
        NotImplementedError: for an unknown ENCODER setting.
    """
    if cfg.DATASET == "HKO":
        # HKO only: multiply the prediction element-wise by an extra graph
        # input named 'mask'.
        pred = pred * mx.sym.Variable('mask')

    encoder = cfg.MODEL.DECONVBASELINE.ENCODER
    use_3d = cfg.MODEL.DECONVBASELINE.USE_3D

    if encoder in ["shared", "separate"]:
        postfix = "" if encoder == "shared" else "_gan"

        context_encoding = encode_net_symbol(
            data=context,
            data_type="context",
            no_bias=no_bias,
            momentum=momentum,
            fix_gamma=fix_gamma,
            eps=eps,
            postfix=postfix)

        # pred is encoded with the default postfix. Its layer names carry
        # the "_pred_" tag, so they never coincide with the context
        # encoder's names. Changing this would rename existing parameters.
        pred_encoding = encode_net_symbol(
            data=pred,
            data_type="pred",
            no_bias=no_bias,
            momentum=momentum,
            fix_gamma=fix_gamma,
            eps=eps)

        # Join along axis 2 in 3-D mode, axis 1 (channels) in 2-D mode.
        context_pred = mx.sym.concat(
            context_encoding, pred_encoding, dim=2 if use_3d else 1)

        if cfg.MODEL.DECONVBASELINE.COMPAT.CONV_INSTEADOF_FC_IN_ENCODER:
            # Legacy head: 1x1 conv+BN+act followed by a 4x4 conv to one
            # filter, instead of the fully connected head below.
            d5 = conv2d_3d_bn_act(
                use_global_stats=cfg.MODEL.DECONVBASELINE.BN_GLOBAL_STATS,
                use_3d=use_3d,
                use_bn=cfg.MODEL.DECONVBASELINE.BN,
                data=context_pred,
                name='discriminator_5',
                act_type=cfg.MODEL.CNN_ACT_TYPE,
                kernel=(1, 1, 1),
                stride=(1, 1, 1),
                pad=(0, 0, 0),
                num_filter=cfg.MODEL.DECONVBASELINE.BASE_NUM_FILTER,
                no_bias=no_bias,
                height=4,
                width=4,
                fix_gamma=fix_gamma,
                eps=eps,
                momentum=momentum)
            d6 = conv2d_3d(
                use_3d=use_3d,
                data=d5,
                name='discriminator_6',
                kernel=(1, 4, 4),
                stride=(1, 1, 1),
                pad=(0, 0, 0),
                num_filter=1,
                no_bias=no_bias)
            return mx.sym.Flatten(d6)

        flattened_encoding = mx.sym.Flatten(data=context_pred)

    elif encoder == "concat":
        context_pred = mx.sym.concat(context, pred, dim=2)
        encoding = encode_net_symbol(
            data=context_pred,
            data_type="contextpred",
            no_bias=no_bias,
            momentum=momentum,
            fix_gamma=fix_gamma,
            eps=eps)
        flattened_encoding = mx.sym.Flatten(data=encoding)

    else:
        raise NotImplementedError

    fc1 = fc_layer_act(
        data=flattened_encoding,
        num_hidden=256,
        name="discriminator_fc_1",
        act_type=cfg.MODEL.CNN_ACT_TYPE)
    return fc_layer(data=fc1, num_hidden=1, name="discriminator_fc_2")
|
|
|
|
| |
def batchnorm_5d(data, height, width, name, fix_gamma, eps, momentum):
    """Apply BatchNorm to a 5-D symbol by temporarily folding it to 4-D.

    The first two axes are kept (reshape code 0). All remaining axes except
    the last are folded into one (-1), leaving a trailing axis of size
    `width`. After mx.sym.BatchNorm the result is unfolded back to
    (dim0, dim1, -1, height, width). use_global_stats comes from
    cfg.MODEL.DECONVBASELINE.BN_GLOBAL_STATS.
    """
    folded = mx.symbol.reshape(data, shape=(0, 0, -1, width))
    normalized = mx.sym.BatchNorm(
        folded,
        name=name,
        momentum=momentum,
        eps=eps,
        fix_gamma=fix_gamma,
        use_global_stats=cfg.MODEL.DECONVBASELINE.BN_GLOBAL_STATS)
    unfolded_shape = (0, 0, -1, height, width)
    return mx.symbol.reshape(normalized, shape=unfolded_shape)
|
|
|
|
def conv2d_3d(data,
              num_filter,
              kernel=(1, 1, 1),
              stride=(1, 1, 1),
              pad=(0, 0, 0),
              dilate=(1, 1, 1),
              no_bias=False,
              name=None,
              use_3d=True,
              **kwargs):
    """Plain convolution, either 3-D or a 2-D stand-in.

    If use_3d == False use a 2D convolution with the same number of
    parameters: the leading (first-axis) entry is dropped from kernel,
    stride, pad and dilate, and num_filter is multiplied by kernel[0].
    Extra keyword arguments are forwarded unchanged.
    """
    if not use_3d:
        return conv2d(
            data=data,
            name=name,
            num_filter=num_filter * kernel[0],
            kernel=kernel[1:],
            stride=stride[1:],
            pad=pad[1:],
            dilate=dilate[1:],
            no_bias=no_bias,
            **kwargs)
    return conv3d(
        data=data,
        name=name,
        num_filter=num_filter,
        kernel=kernel,
        stride=stride,
        pad=pad,
        dilate=dilate,
        no_bias=no_bias,
        **kwargs)
|
|
|
|
def conv2d_3d_bn_act(data,
                     num_filter,
                     height,
                     width,
                     kernel=(1, 1, 1),
                     stride=(1, 1, 1),
                     pad=(0, 0, 0),
                     dilate=(1, 1, 1),
                     no_bias=False,
                     act_type="relu",
                     momentum=0.9,
                     eps=1e-5 + 1e-12,
                     fix_gamma=True,
                     name=None,
                     use_3d=True,
                     use_bn=True,
                     use_global_stats=False,
                     **kwargs):
    """Convolution + BatchNorm + activation, 3-D or a 2-D stand-in.

    If use_3d == False use a 2D convolution with the same number of
    parameters: the leading entry is dropped from kernel, stride, pad and
    dilate, and num_filter is multiplied by kernel[0].

    height/width are passed only to the 3-D helper (conv3d_bn_act).
    When use_bn is False the call is delegated to conv2d_3d_act. On that
    path the BN arguments (height, width, momentum, eps, fix_gamma,
    use_global_stats) and **kwargs are not forwarded.
    """
    if use_bn and use_3d:
        return conv3d_bn_act(
            data=data,
            name=name,
            num_filter=num_filter,
            kernel=kernel,
            stride=stride,
            pad=pad,
            dilate=dilate,
            no_bias=no_bias,
            act_type=act_type,
            height=height,
            width=width,
            momentum=momentum,
            eps=eps,
            fix_gamma=fix_gamma,
            use_global_stats=use_global_stats,
            **kwargs)
    if use_bn:
        return conv2d_bn_act(
            data=data,
            name=name,
            num_filter=num_filter * kernel[0],
            kernel=kernel[1:],
            stride=stride[1:],
            pad=pad[1:],
            dilate=dilate[1:],
            no_bias=no_bias,
            act_type=act_type,
            momentum=momentum,
            eps=eps,
            fix_gamma=fix_gamma,
            use_global_stats=use_global_stats,
            **kwargs)
    # No BatchNorm requested: plain conv + activation.
    return conv2d_3d_act(
        data=data,
        name=name,
        use_3d=use_3d,
        num_filter=num_filter,
        kernel=kernel,
        stride=stride,
        pad=pad,
        dilate=dilate,
        no_bias=no_bias,
        act_type=act_type)
|
|
|
|
def conv2d_3d_act(data,
                  num_filter,
                  kernel=(1, 1, 1),
                  stride=(1, 1, 1),
                  pad=(0, 0, 0),
                  dilate=(1, 1, 1),
                  no_bias=False,
                  act_type="relu",
                  name=None,
                  use_3d=True,
                  **kwargs):
    """Convolution + activation, 3-D or a 2-D stand-in.

    If use_3d == False use a 2D convolution with the same number of
    parameters: the leading entry is dropped from kernel, stride, pad and
    dilate, and num_filter is multiplied by kernel[0].
    """
    if use_3d:
        layer_fn = conv3d_act
        filters = num_filter
        geom_kernel, geom_stride, geom_pad, geom_dilate = (
            kernel, stride, pad, dilate)
    else:
        layer_fn = conv2d_act
        filters = num_filter * kernel[0]
        geom_kernel, geom_stride, geom_pad, geom_dilate = (
            kernel[1:], stride[1:], pad[1:], dilate[1:])
    return layer_fn(
        data=data,
        num_filter=filters,
        kernel=geom_kernel,
        stride=geom_stride,
        pad=geom_pad,
        dilate=geom_dilate,
        no_bias=no_bias,
        act_type=act_type,
        name=name,
        **kwargs)
|
|
|
|
def deconv2d_3d(data,
                num_filter,
                kernel=(1, 1, 1),
                stride=(1, 1, 1),
                pad=(0, 0, 0),
                adj=(0, 0, 0),
                no_bias=True,
                target_shape=None,
                name=None,
                use_3d=True,
                **kwargs):
    """Plain deconvolution (no activation), 3-D or a 2-D stand-in.

    If use_3d == False use a 2D deconvolution with the same number of
    parameters: the leading entry is dropped from kernel, stride, pad and
    adj, and num_filter is multiplied by kernel[0]. target_shape and extra
    keyword arguments are forwarded unchanged.

    The *_act helpers are not used here. This wrapper has no act_type
    parameter, and referencing one raised NameError on every call.
    """
    if use_3d:
        return deconv3d(
            data=data,
            num_filter=num_filter,
            kernel=kernel,
            stride=stride,
            pad=pad,
            adj=adj,
            no_bias=no_bias,
            target_shape=target_shape,
            name=name,
            **kwargs)
    return deconv2d(
        data=data,
        num_filter=num_filter * kernel[0],
        kernel=kernel[1:],
        stride=stride[1:],
        pad=pad[1:],
        adj=adj[1:],
        no_bias=no_bias,
        target_shape=target_shape,
        name=name,
        **kwargs)
|
|
|
|
def deconv2d_3d_bn_act(data,
                       num_filter,
                       height,
                       width,
                       kernel=(1, 1, 1),
                       stride=(1, 1, 1),
                       pad=(0, 0, 0),
                       adj=(0, 0, 0),
                       no_bias=True,
                       target_shape=None,
                       act_type="relu",
                       momentum=0.9,
                       eps=1e-5 + 1e-12,
                       fix_gamma=True,
                       name=None,
                       use_3d=True,
                       use_bn=True,
                       use_global_stats=False,
                       **kwargs):
    """Deconvolution + BatchNorm + activation, 3-D or a 2-D stand-in.

    If use_3d == False use a 2D deconvolution with the same number of
    parameters: the leading entry is dropped from kernel, stride, pad and
    adj, and num_filter is multiplied by kernel[0].

    height/width are passed only to the 3-D helper (deconv3d_bn_act).
    When use_bn is False the call is delegated to deconv2d_3d_act, with
    target_shape forwarded. The BN arguments and **kwargs are not forwarded
    on that path.
    """
    if not use_bn:
        return deconv2d_3d_act(
            data=data,
            num_filter=num_filter,
            kernel=kernel,
            stride=stride,
            pad=pad,
            adj=adj,
            no_bias=no_bias,
            # Previously omitted, so target_shape was silently ignored
            # whenever BatchNorm was disabled.
            target_shape=target_shape,
            act_type=act_type,
            name=name,
            use_3d=use_3d)

    if use_3d:
        return deconv3d_bn_act(
            data=data,
            num_filter=num_filter,
            height=height,
            width=width,
            kernel=kernel,
            stride=stride,
            pad=pad,
            adj=adj,
            no_bias=no_bias,
            target_shape=target_shape,
            act_type=act_type,
            momentum=momentum,
            eps=eps,
            fix_gamma=fix_gamma,
            name=name,
            use_global_stats=use_global_stats,
            **kwargs)
    return deconv2d_bn_act(
        data=data,
        num_filter=num_filter * kernel[0],
        kernel=kernel[1:],
        stride=stride[1:],
        pad=pad[1:],
        adj=adj[1:],
        no_bias=no_bias,
        target_shape=target_shape,
        act_type=act_type,
        momentum=momentum,
        eps=eps,
        fix_gamma=fix_gamma,
        name=name,
        use_global_stats=use_global_stats,
        **kwargs)
|
|
|
|
def deconv2d_3d_act(data,
                    num_filter,
                    kernel=(1, 1, 1),
                    stride=(1, 1, 1),
                    pad=(0, 0, 0),
                    adj=(0, 0, 0),
                    no_bias=True,
                    target_shape=None,
                    act_type="relu",
                    name=None,
                    use_3d=True,
                    **kwargs):
    """Deconvolution + activation, 3-D or a 2-D stand-in.

    If use_3d == False use a 2D deconvolution with the same number of
    parameters: the leading entry is dropped from kernel, stride, pad and
    adj, and num_filter is multiplied by kernel[0]. target_shape and extra
    keyword arguments are forwarded unchanged.
    """
    if not use_3d:
        return deconv2d_act(
            data=data,
            name=name,
            num_filter=num_filter * kernel[0],
            kernel=kernel[1:],
            stride=stride[1:],
            pad=pad[1:],
            adj=adj[1:],
            target_shape=target_shape,
            no_bias=no_bias,
            act_type=act_type,
            **kwargs)
    return deconv3d_act(
        data=data,
        name=name,
        num_filter=num_filter,
        kernel=kernel,
        stride=stride,
        pad=pad,
        adj=adj,
        target_shape=target_shape,
        no_bias=no_bias,
        act_type=act_type,
        **kwargs)
|
|