| import torch |
| from configuration_neuroclr import NeuroCLRConfig |
| from modeling_neuroclr import NeuroCLRModel |
|
|
| |
# Hyper-parameters forwarded verbatim into NeuroCLRConfig (see main()).
CFG = {
    "TSlength": 128,
    "nhead": 2,
    "nlayer": 2,
    "projector_out1": 128,
    "projector_out2": 64,
    "pooling": "flatten",
    "normalize_input": True,
}

# TODO: point this at the training checkpoint file before running.
CKPT_PATH = ""
# Directory the converted Hugging Face model is written to.
OUT_DIR = "."
| |
|
|
def remap_state_dict(sd):
    """Translate raw training-checkpoint keys to NeuroCLRModel keys.

    Strips a DataParallel/DDP ``module.`` prefix, then nests the
    ``transformer_encoder.*`` and ``projector.*`` weights under the
    ``neuroclr.`` submodule used by the HF wrapper. All other keys
    pass through unchanged (values are never touched).

    Args:
        sd: mapping of parameter name -> tensor (a raw state dict).

    Returns:
        New dict with remapped keys, same values.
    """
    new_sd = {}
    for k, v in sd.items():
        # Only strip a LEADING "module." — the previous str.replace would
        # also mangle a "module." occurring in the middle of a key.
        k2 = k.removeprefix("module.")
        if k2.startswith(("transformer_encoder.", "projector.")):
            new_sd["neuroclr." + k2] = v
        else:
            new_sd[k2] = v
    return new_sd
|
|
def main():
    """Build a NeuroCLRModel from CFG, load weights from CKPT_PATH, and
    save the result in Hugging Face format to OUT_DIR.

    Raises:
        ValueError: if CKPT_PATH has not been set.
    """
    # Fail fast with an actionable message instead of letting torch.load
    # raise an opaque error on the default empty path.
    if not CKPT_PATH:
        raise ValueError(
            "CKPT_PATH is empty; set it to the checkpoint file to convert."
        )

    config = NeuroCLRConfig(**CFG)

    # auto_map lets AutoConfig/AutoModel resolve the custom classes when
    # the exported model is later loaded with trust_remote_code=True.
    config.auto_map = {
        "AutoConfig": "configuration_neuroclr.NeuroCLRConfig",
        "AutoModel": "modeling_neuroclr.NeuroCLRModel",
    }

    model = NeuroCLRModel(config)

    # NOTE(review): torch.load unpickles arbitrary objects — only run this
    # on checkpoints you trust. map_location keeps everything on CPU.
    ckpt = torch.load(CKPT_PATH, map_location="cpu")

    # Training scripts wrap the weights under different keys; unwrap the
    # two common layouts, otherwise assume the checkpoint IS the state dict.
    if isinstance(ckpt, dict) and "model_state_dict" in ckpt:
        sd = ckpt["model_state_dict"]
    elif isinstance(ckpt, dict) and "state_dict" in ckpt:
        sd = ckpt["state_dict"]
    else:
        sd = ckpt

    sd = remap_state_dict(sd)

    # strict=False tolerates renamed/extra keys; print both lists so the
    # operator can verify nothing essential was dropped.
    missing, unexpected = model.load_state_dict(sd, strict=False)
    print("Missing:", missing)
    print("Unexpected:", unexpected)

    model.save_pretrained(OUT_DIR, safe_serialization=True)
    print("Saved HF pretraining model to:", OUT_DIR)
|
|
# Run the conversion when executed as a standalone script.
if __name__ == "__main__":
    main()
|
|