import copy
import json
|
|
# Fallback configuration. load_config() merges these values beneath the
# contents of a JSON config file, so any key the file omits (at any nesting
# level) ends up with the value listed here.
DEFAULTS = {
    "network": "dpn",
    "encoder": "dpn92",
    "model_params": {},
    "optimizer": {
        "batch_size": 32,
        "type": "SGD",
        "momentum": 0.9,
        "weight_decay": 0,
        "clip": 1.0,
        "learning_rate": 0.1,
        # NOTE(review): -1 looks like a "not set" sentinel; confirm how the
        # consumer of this key interprets it.
        "classifier_lr": -1,
        "nesterov": True,
        "schedule": {
            "type": "constant",
            "mode": "epoch",
            "epochs": 10,
            "params": {},
        },
    },
    # NOTE(review): these values match the commonly used ImageNet per-channel
    # statistics; presumably RGB order -- confirm against the data pipeline.
    "normalize": {
        "mean": [0.485, 0.456, 0.406],
        "std": [0.229, 0.224, 0.225],
    },
}
|
|
|
|
| def _merge(src, dst): |
| for k, v in src.items(): |
| if k in dst: |
| if isinstance(v, dict): |
| _merge(src[k], dst[k]) |
| else: |
| dst[k] = v |
|
|
|
|
def load_config(config_file, defaults=DEFAULTS):
    """Load a JSON configuration file and complete it with default values.

    Args:
        config_file: Path of the JSON file to read.
        defaults: Mapping of fallback values merged beneath the file's
            contents. Defaults to the module-level ``DEFAULTS``.

    Returns:
        The parsed configuration. Values present in the file are kept; keys
        the file omits, at any nesting level, are filled in from
        ``defaults``.
    """
    # JSON text is UTF-8 by specification, so do not depend on the platform's
    # default encoding when reading the file.
    with open(config_file, "r", encoding="utf-8") as fd:
        config = json.load(fd)
    _merge(defaults, config)
    return config
|
|