"""Load JSON configuration files, filling unspecified options from DEFAULTS."""
import copy
import json

DEFAULTS = {
    "network": "dpn",
    "encoder": "dpn92",
    "model_params": {},
    "optimizer": {
        "batch_size": 32,
        "type": "SGD",  # supported: SGD, Adam
        "momentum": 0.9,
        "weight_decay": 0,
        "clip": 1.0,
        "learning_rate": 0.1,
        "classifier_lr": -1,
        "nesterov": True,
        "schedule": {
            "type": "constant",  # supported: constant, step, multistep, exponential, linear, poly
            "mode": "epoch",  # supported: epoch, step
            "epochs": 10,
            "params": {},
        },
    },
    "normalize": {
        "mean": [0.485, 0.456, 0.406],
        "std": [0.229, 0.224, 0.225],
    },
}


def _merge(src, dst):
    """Recursively fill keys missing from ``dst`` with values from ``src``.

    ``dst`` is modified in place and ``src`` is never modified. Values already
    present in ``dst`` take precedence. When both sides hold a dict for the
    same key, the merge recurses into that sub-dict. Values taken from ``src``
    are deep-copied, so ``dst`` never shares mutable objects with ``src``.

    Args:
        src: Mapping of fallback values (e.g. ``DEFAULTS``).
        dst: Dict to complete in place.
    """
    for key, default in src.items():
        if key not in dst:
            # Deep-copy so later edits to the returned config cannot leak back
            # into ``src`` (e.g. the shared module-level DEFAULTS).
            dst[key] = copy.deepcopy(default)
        elif isinstance(default, dict) and isinstance(dst[key], dict):
            _merge(default, dst[key])
        # Otherwise the user's explicit value (including a non-dict override
        # of a dict default, such as null) is kept as-is.


def load_config(config_file, defaults=DEFAULTS):
    """Read a JSON config file and fill in any options it does not specify.

    Args:
        config_file: Path to a UTF-8 JSON file. Its top level is expected to
            be an object, because the merge indexes it as a dict.
        defaults: Nested dict of fallback values. It is not modified.

    Returns:
        The parsed config dict, with keys missing at any nesting level filled
        from ``defaults``. Values given in the file take precedence.
    """
    with open(config_file, "r", encoding="utf-8") as fd:
        config = json.load(fd)
    _merge(defaults, config)
    return config