| |
|
|
| """Extract the model backbone from the checkpoint.""" |
|
|
| import torch |
|
|
| from torchgeo.models import dofa_base_patch16_224 |
|
|
| |
# Full pretraining checkpoint to extract the backbone from.
in_filename = "DOFA_ViT_base_e100.pth"
# Map every tensor to CPU so the script runs on machines without a GPU.
# weights_only=True limits unpickling to tensors and plain containers, so a
# tampered checkpoint cannot execute arbitrary code. This is also the default
# from PyTorch 2.6 onward. The checkpoint is consumed below as a flat
# state dict, so nothing richer needs to be unpickled.
weights = torch.load(
    in_filename, map_location=torch.device("cpu"), weights_only=True
)
|
|
| |
# Remove checkpoint entries that are not part of the backbone: the mask
# token, the final norm layer, and the projector. Deletion runs left to
# right, and a missing key raises KeyError.
del (
    weights["mask_token"],
    weights["norm.weight"],
    weights["norm.bias"],
    weights["projector.weight"],
    weights["projector.bias"],
)
|
|
| |
| |
# Parameters the stripped checkpoint may legitimately lack. Judging by the
# names, these are the model's fc_norm/head output layers, which are not part
# of the extracted backbone (TODO confirm).
allowed_missing_keys = {"fc_norm.weight", "fc_norm.bias", "head.weight", "head.bias"}
model = dofa_base_patch16_224()
# strict=False so the allowed missing keys do not abort the load; the two
# checks below re-impose strictness for everything else.
missing_keys, unexpected_keys = model.load_state_dict(weights, strict=False)
# Raise explicitly instead of using `assert`: assertions are stripped under
# `python -O`, which would let an incompatible checkpoint be saved silently.
extra_missing_keys = set(missing_keys) - allowed_missing_keys
if extra_missing_keys:
    raise RuntimeError(
        f"Checkpoint is missing backbone keys: {sorted(extra_missing_keys)}"
    )
if unexpected_keys:
    raise RuntimeError(
        f"Checkpoint has keys the model does not expect: {sorted(unexpected_keys)}"
    )
|
|
| |
| |
# Write the pruned, verified state dict (not the model object). The file is
# named after the builder function used for verification above.
out_filename = "dofa_base_patch16_224.pth"
torch.save(obj=weights, f=out_filename)
|
|