"""Convert a legacy ModernVBERT checkpoint into the official transformers format.

The conversion does three things:
  * merges the additional token-embedding table into the main embedding matrix,
  * renames the connector modality-projection weight,
  * re-nests vision-tower tensors under the doubled ``vision_model`` prefix.
"""

from __future__ import annotations

import argparse
import json
import shutil
from pathlib import Path

import torch
from safetensors.torch import load_file, save_file

# Legacy -> transformers tensor names. The "IN" names exist in the legacy
# checkpoint; the "OUT" names are what the transformers format expects.
TEXT_EMBED_KEY = "model.text_model.embeddings.tok_embeddings.weight"
TEXT_EXTRA_EMBED_KEY = "model.text_model.embeddings.tok_embeddings.additional_embedding.weight"
CONNECTOR_IN = "model.connector.modality_projection.proj.weight"
CONNECTOR_OUT = "model.connector.modality_projection.weight"
VISION_PREFIX_IN = "model.vision_model."
VISION_PREFIX_OUT = "model.vision_model.vision_model."


def parse_args() -> argparse.Namespace:
    parser = argparse.ArgumentParser(
        description="Convert a legacy ModernVBERT checkpoint into the official transformers format."
    )
    parser.add_argument("input_dir", type=Path, help="Legacy model directory")
    parser.add_argument("output_dir", type=Path, help="Converted model directory")
    parser.add_argument(
        "--config-template",
        type=Path,
        default=Path(__file__).resolve().with_name("config.json"),
        help="Transformers-format config.json template to copy into the converted output",
    )
    return parser.parse_args()


def ensure_new_dir(path: Path) -> None:
    """Create ``path``, refusing to clobber an existing directory."""
    if path.exists():
        raise FileExistsError(f"{path} already exists; refusing to overwrite it")
    path.mkdir(parents=True)


def copy_support_files(src: Path, dst: Path) -> None:
    """Copy support files verbatim, skipping the files this script regenerates."""
    excluded = {"model.safetensors", "config.json", "BUILD_INFO.json"}
    for item in src.iterdir():
        if item.name in excluded:
            continue
        target = dst / item.name
        if item.is_dir():
            shutil.copytree(item, target)
        else:
            shutil.copy2(item, target)


def convert_model_weights(src_dir: Path, dst_dir: Path) -> dict[str, int]:
    """Rewrite tensor keys, merge the embedding tables, and return conversion stats."""
    src_weights = load_file(str(src_dir / "model.safetensors"))
    out_weights: dict[str, torch.Tensor] = {}
    merged_embeddings = 0
    renamed_connector = 0
    renamed_vision = 0
    for key, value in src_weights.items():
        if key == TEXT_EXTRA_EMBED_KEY:
            # Dropped here; its rows are appended to the main table below.
            continue
        if key == TEXT_EMBED_KEY and TEXT_EXTRA_EMBED_KEY in src_weights:
            value = torch.cat([value, src_weights[TEXT_EXTRA_EMBED_KEY]], dim=0)
            merged_embeddings += 1
        if key == CONNECTOR_IN:
            key = CONNECTOR_OUT
            renamed_connector += 1
        elif key.startswith(VISION_PREFIX_IN) and not key.startswith(VISION_PREFIX_OUT):
            key = VISION_PREFIX_OUT + key[len(VISION_PREFIX_IN):]
            renamed_vision += 1
        out_weights[key] = value
    save_file(out_weights, str(dst_dir / "model.safetensors"))
    return {
        "source_tensor_count": len(src_weights),
        "output_tensor_count": len(out_weights),
        "merged_token_embedding_tables": merged_embeddings,
        "renamed_connector_tensors": renamed_connector,
        "renamed_vision_tensors": renamed_vision,
    }


def write_config(template_path: Path, dst_dir: Path) -> dict[str, str]:
    """Copy the transformers-format config template into the output directory."""
    if not template_path.exists():
        raise FileNotFoundError(f"Config template not found: {template_path}")
    config = json.loads(template_path.read_text())
    (dst_dir / "config.json").write_text(json.dumps(config, indent=2) + "\n")
    return {"config_template": str(template_path)}


def main() -> None:
    args = parse_args()
    ensure_new_dir(args.output_dir)
    copy_support_files(args.input_dir, args.output_dir)
    weight_info = convert_model_weights(args.input_dir, args.output_dir)
    config_info = write_config(args.config_template, args.output_dir)
    # Record provenance and the applied key mapping alongside the weights.
    build_info = {
        "description": "Legacy ModernVBERT checkpoint converted to the official transformers format.",
        "input_dir": str(args.input_dir),
        "output_dir": str(args.output_dir),
        **weight_info,
        **config_info,
        "key_mapping": {
            TEXT_EXTRA_EMBED_KEY: f"merged into {TEXT_EMBED_KEY}",
            CONNECTOR_IN: CONNECTOR_OUT,
            VISION_PREFIX_IN: VISION_PREFIX_OUT,
        },
    }
    (args.output_dir / "BUILD_INFO.json").write_text(json.dumps(build_info, indent=2) + "\n")
    print(f"Wrote {args.output_dir}")
    print(f"Converted {weight_info['output_tensor_count']} model tensors")


if __name__ == "__main__":
    main()