| import torch |
| from configuration_neuroclr import NeuroCLRConfig |
| from modeling_neuroclr import NeuroCLRForSequenceClassification |
|
|
| |
# Checkpoint paths and output location -- fill these in before running.
PRETRAIN_CKPT = ""  # path to the self-supervised (encoder) pretraining checkpoint
HEAD_CKPT = ""      # path to the classification-head checkpoint
OUT_DIR = "."       # directory where the Hugging Face model files are written


# Keyword arguments passed verbatim to NeuroCLRConfig (see main()).
# NOTE(review): the grouping below is inferred from the parameter names --
# confirm against configuration_neuroclr.NeuroCLRConfig.
CFG = dict(
    # --- transformer encoder (presumably; verify) ---
    TSlength=128,        # presumably input time-series length -- confirm
    nhead=2,
    nlayer=2,
    projector_out1=128,
    projector_out2=64,
    pooling="flatten",
    normalize_input=True,

    # --- classification task ---
    n_rois=200,          # presumably number of ROIs / input channels -- confirm
    num_labels=2,
    freeze_encoder=True,

    # --- convolutional head (matches the prefixes remapped in remap_head) ---
    base_filters=256,
    kernel_size=16,
    stride=2,
    groups=32,
    n_block=48,
    downsample_gap=6,
    increasefilter_gap=12,
    use_bn=True,
    use_do=True,
)
| |
|
|
def load_model_state_dict(path):
    """Load a checkpoint file and return the model state dict inside it.

    Training checkpoints are often wrappers like
    ``{"model_state_dict": ..., "optimizer": ...}`` or ``{"state_dict": ...}``;
    unwrap the first known key found, otherwise return the loaded object as-is.

    Args:
        path: Filesystem path to a ``torch.save``-produced checkpoint.

    Returns:
        The state dict (parameter-name -> tensor mapping) found in the
        checkpoint, or the raw loaded object when no known wrapper key exists.
    """
    # NOTE: torch.load unpickles arbitrary objects -- only load checkpoints
    # from trusted sources.
    ckpt = torch.load(path, map_location="cpu")
    if isinstance(ckpt, dict):
        # Probe the common wrapper keys in priority order.
        for key in ("model_state_dict", "state_dict"):
            if key in ckpt:
                return ckpt[key]
    # Either not a dict, or already a bare state dict.
    return ckpt
|
|
def remap_encoder(sd):
    """Select encoder weights from a pretraining state dict for the HF model.

    Keeps only ``transformer_encoder.*`` and ``projector.*`` entries and
    prefixes them with ``encoder.`` to match
    ``NeuroCLRForSequenceClassification``'s parameter naming.

    Args:
        sd: Raw state dict from the pretraining checkpoint.

    Returns:
        New dict with remapped keys; non-encoder entries are dropped.
    """
    prefix = "module."
    remapped = {}
    for key, value in sd.items():
        # Strip only a *leading* DataParallel "module." prefix.
        # (The original str.replace removed the substring anywhere in the
        # key, which could corrupt legitimate names containing "module.".)
        if key.startswith(prefix):
            key = key[len(prefix):]
        if key.startswith(("transformer_encoder.", "projector.")):
            remapped["encoder." + key] = value
    return remapped
|
|
def remap_head(sd):
    """Select classification-head weights from a head checkpoint state dict.

    Entries whose names match known head-layer prefixes are re-rooted under
    ``head.``; entries already under ``head.`` are kept as-is. Everything
    else is dropped.

    Args:
        sd: Raw state dict from the head checkpoint.

    Returns:
        New dict with remapped keys; non-head entries are dropped.
    """
    # Layer-name prefixes of the convolutional head (hoisted out of the loop;
    # the original rebuilt this tuple on every iteration).
    head_prefixes = (
        "first_block_conv.", "first_block_bn.", "first_block_relu.",
        "basicblock_list.", "final_bn.", "final_relu.", "dense.",
    )
    dp_prefix = "module."
    remapped = {}
    for key, value in sd.items():
        # Strip only a *leading* DataParallel "module." prefix.
        # (The original str.replace removed the substring anywhere in the
        # key, which could corrupt legitimate names containing "module.".)
        if key.startswith(dp_prefix):
            key = key[len(dp_prefix):]
        if key.startswith(head_prefixes):
            remapped["head." + key] = value
        elif key.startswith("head."):
            # Checkpoint was already saved with the HF naming -- keep as-is.
            remapped[key] = value
    return remapped
|
|
def main():
    """Assemble a HF classification model from two checkpoints and save it."""
    config = NeuroCLRConfig(**CFG)

    # Register the custom classes so Auto* loading can resolve them from the
    # saved repo files.
    config.auto_map = {
        "AutoConfig": "configuration_neuroclr.NeuroCLRConfig",
        "AutoModelForSequenceClassification": "modeling_neuroclr.NeuroCLRForSequenceClassification",
    }

    model = NeuroCLRForSequenceClassification(config)

    # Remap both checkpoints into the HF parameter namespace and merge them;
    # head entries overwrite encoder entries on any key collision.
    merged = dict(remap_encoder(load_model_state_dict(PRETRAIN_CKPT)))
    merged.update(remap_head(load_model_state_dict(HEAD_CKPT)))

    # strict=False: report, rather than fail on, keys that do not line up.
    missing, unexpected = model.load_state_dict(merged, strict=False)
    print("Missing:", missing)
    print("Unexpected:", unexpected)

    model.save_pretrained(OUT_DIR, safe_serialization=True)
    print("Saved HF classification model to:", OUT_DIR)
|
|
# Run the export only when executed as a script (allows importing the
# remap helpers from elsewhere without side effects).
if __name__ == "__main__":
    main()
|
|