paultltc commited on
Commit
9052fc0
·
verified ·
1 Parent(s): 3b02d5a

Remove adapter migration section and script

Browse files
Files changed (1) hide show
  1. convert_adapter_weights.py +0 -107
convert_adapter_weights.py DELETED
@@ -1,107 +0,0 @@
1
- from __future__ import annotations
2
-
3
- import argparse
4
- import json
5
- import shutil
6
- from pathlib import Path
7
-
8
- from safetensors.torch import load_file, save_file
9
-
10
-
11
- TEXT_PREFIX_IN = "base_model.model.model.text_model."
12
- TEXT_PREFIX_OUT = "model.text_model."
13
- PROJ_PREFIX_IN = "base_model.model.custom_text_proj."
14
- PROJ_PREFIX_OUT = "model.custom_text_proj."
15
-
16
-
17
- def parse_args() -> argparse.Namespace:
18
- parser = argparse.ArgumentParser(
19
- description="Convert the legacy paultltc/colmodernvbert LoRA adapter into the official ModernVBERT/colmodernvbert format."
20
- )
21
- parser.add_argument("input_dir", type=Path, help="Legacy adapter directory")
22
- parser.add_argument("output_dir", type=Path, help="Converted adapter directory")
23
- return parser.parse_args()
24
-
25
-
26
- def ensure_new_dir(path: Path) -> None:
27
- if path.exists():
28
- raise FileExistsError(f"{path} already exists; refusing to overwrite it")
29
- path.mkdir(parents=True)
30
-
31
-
32
- def copy_support_files(src: Path, dst: Path) -> None:
33
- excluded = {"adapter_model.safetensors", "adapter_config.json", "BUILD_INFO.json"}
34
- for item in src.iterdir():
35
- if item.name in excluded:
36
- continue
37
- target = dst / item.name
38
- if item.is_dir():
39
- shutil.copytree(item, target)
40
- else:
41
- shutil.copy2(item, target)
42
-
43
-
44
- def convert_adapter_weights(src_dir: Path, dst_dir: Path) -> dict[str, int]:
45
- src_weights = load_file(str(src_dir / "adapter_model.safetensors"))
46
- out_weights = {}
47
- converted_text = 0
48
- converted_proj = 0
49
- untouched = []
50
-
51
- for key, value in src_weights.items():
52
- if key.startswith(TEXT_PREFIX_IN):
53
- out_key = TEXT_PREFIX_OUT + key[len(TEXT_PREFIX_IN) :]
54
- out_weights[out_key] = value
55
- converted_text += 1
56
- elif key.startswith(PROJ_PREFIX_IN):
57
- out_key = PROJ_PREFIX_OUT + key[len(PROJ_PREFIX_IN) :]
58
- out_weights[out_key] = value
59
- converted_proj += 1
60
- else:
61
- untouched.append(key)
62
-
63
- if untouched:
64
- sample = ", ".join(sorted(untouched)[:5])
65
- raise ValueError(f"Found unexpected legacy adapter keys that were not converted: {sample}")
66
-
67
- save_file(out_weights, str(dst_dir / "adapter_model.safetensors"))
68
- return {
69
- "source_tensor_count": len(src_weights),
70
- "output_tensor_count": len(out_weights),
71
- "converted_text_tensors": converted_text,
72
- "converted_projection_tensors": converted_proj,
73
- }
74
-
75
-
76
- def convert_adapter_config(src_dir: Path, dst_dir: Path) -> None:
77
- config = json.loads((src_dir / "adapter_config.json").read_text())
78
- config["target_modules"] = "(.*(model.text_model).*(Wo|Wqkv|Wi).*$|.*(custom_text_proj).*$)"
79
- (dst_dir / "adapter_config.json").write_text(json.dumps(config, indent=2) + "\n")
80
-
81
-
82
- def main() -> None:
83
- args = parse_args()
84
- ensure_new_dir(args.output_dir)
85
- copy_support_files(args.input_dir, args.output_dir)
86
-
87
- weight_info = convert_adapter_weights(args.input_dir, args.output_dir)
88
- convert_adapter_config(args.input_dir, args.output_dir)
89
-
90
- build_info = {
91
- "description": "Legacy paultltc/colmodernvbert adapter converted to the official ModernVBERT/colmodernvbert key layout.",
92
- "input_dir": str(args.input_dir),
93
- "output_dir": str(args.output_dir),
94
- **weight_info,
95
- "key_mapping": {
96
- TEXT_PREFIX_IN: TEXT_PREFIX_OUT,
97
- PROJ_PREFIX_IN: PROJ_PREFIX_OUT,
98
- },
99
- }
100
- (args.output_dir / "BUILD_INFO.json").write_text(json.dumps(build_info, indent=2) + "\n")
101
-
102
- print(f"Wrote {args.output_dir}")
103
- print(f"Converted {weight_info['output_tensor_count']} adapter tensors")
104
-
105
-
106
- if __name__ == "__main__":
107
- main()