| """ |
| Module for testing the 3D clutter filtering model. |
| """ |
| import os |
| import argparse |
| import json |
| import numpy as np |
| import pandas as pd |
|
|
| from utils import * |
| from Model_ClutterFilter3D import clutter_filter_3D |
| from DataGen import DataGen |
| from Error_analysis import compute_mae |
| |
def data_generation(in_ids_te, out_ids_te, config):
    """Build the (non-training) data generator for the test split.

    Args:
        in_ids_te: Sequence of test input identifiers; forwarded to
            ``DataGen`` as ``in_dir``.
        out_ids_te: Sequence of test target identifiers; forwarded to
            ``DataGen`` as ``out_dir``.
        config: Parsed configuration dict. Reads
            ``config["network_prm"]["input_dim"]`` and
            ``config["learning_prm"]["batch_size"]``.

    Returns:
        A ``DataGen`` instance built with ``tr_phase=False``.
    """
    input_dim = config["network_prm"]["input_dim"]
    # One integer index per test input id: 0 .. len(in_ids_te) - 1.
    sample_indices = np.arange(len(in_ids_te))
    batch_size = config["learning_prm"]["batch_size"]
    return DataGen(
        dim=input_dim,
        in_dir=in_ids_te,
        out_dir=out_ids_te,
        id_list=sample_indices,
        batch_size=batch_size,
        tr_phase=False,
    )
|
|
def main(config):
    """Evaluate saved clutter-filter weights on the test split.

    Steps: prepare the test ids, build the test generator and the model
    from ``config``, load ``<weight_name>.hdf5`` from the weight
    directory, predict on the test generator, and write the DataFrame
    returned by ``compute_mae`` to ``<weight_name>.csv`` in that same
    directory.

    Args:
        config: Parsed configuration dict. It is also unpacked as keyword
            arguments into ``clutter_filter_3D``, and ``config["weight_name"]``
            supplies the base file name for the weights and the error table.

    Returns:
        None.
    """
    in_ids_te, out_ids_te, te_subject, val_subject = id_preparation(config)
    te_gen = data_generation(in_ids_te, out_ids_te, config)
    model = clutter_filter_3D(**config)

    weight_dir = create_weight_dir(val_subject, te_subject, config)
    weight_file = os.path.join(weight_dir, config["weight_name"] + ".hdf5")
    error_file = os.path.join(weight_dir, config["weight_name"] + ".csv")

    model.load_weights(weight_file)
    # NOTE(review): `predict_generator` is deprecated in TF >= 2.1 Keras,
    # where `model.predict` accepts generators directly; confirm the
    # pinned Keras version before switching.
    results_te = model.predict_generator(te_gen, verbose=2)

    # NOTE(review): the error table is computed from the *input* ids, not
    # `out_ids_te`; confirm `compute_mae` derives its ground truth from them.
    df_errors = compute_mae(in_ids_te, results_te)
    df_errors.to_csv(error_file)
|
|
if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument("--config", help="path of the config file",
                        default="config.json")
    args = parser.parse_args()
    # Validate explicitly instead of using `assert`: asserts are stripped
    # under `python -O`, and parser.error() gives a clear usage message
    # and exit status 2.
    if not os.path.isfile(args.config):
        parser.error("config file not found: {}".format(args.config))
    # JSON text is UTF-8; do not depend on the platform default encoding.
    with open(args.config, "r", encoding="utf-8") as read_file:
        config = json.load(read_file)
    main(config)