File size: 4,303 Bytes
45ac12e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
import dataclasses
import logging
from pathlib import Path

import torch
import tyro
from safetensors.torch import load_file, save_file


@dataclasses.dataclass
class Args:
    """CLI arguments (parsed by tyro) for merging a .pth diff checkpoint
    into base safetensors weights as: final = base + scale * diff.
    """

    # Base pretrained weights in safetensors
    base_safetensors: str

    # Diff checkpoint in .pth (either {"state_dict": ...} or raw state_dict)
    diff_pth: str

    # Output safetensors path
    out_safetensors: str = "model_merged.safetensors"

    # final = base + scale * diff
    scale: float = 1.0

    # whether keys must match exactly
    strict_keys: bool = True  # use --strict-keys / --no-strict-keys

    # arithmetic dtype used during the merge (result is cast back to base dtype)
    dtype: str = "fp32"  # fp32/fp16/bf16

    # compute device
    device: str = "cpu"  # cpu/cuda


def cast(t: torch.Tensor, dtype: str) -> torch.Tensor:
    """Return *t* converted to the arithmetic dtype named by *dtype*.

    Supported names: "fp32", "fp16", "bf16".

    Raises:
        ValueError: if *dtype* is not a supported name.
    """
    converters = {
        "fp32": torch.Tensor.float,
        "fp16": torch.Tensor.half,
        "bf16": torch.Tensor.bfloat16,
    }
    try:
        convert = converters[dtype]
    except KeyError:
        raise ValueError(f"Unknown dtype: {dtype}") from None
    return convert(t)


def load_diff_state_dict(path: str) -> dict[str, torch.Tensor]:
    """Load a diff checkpoint saved with ``torch.save`` and return its state dict.

    Accepts either a raw state dict or a wrapper dict containing a
    ``"state_dict"`` key.

    Args:
        path: Path to the ``.pth`` checkpoint.

    Returns:
        Mapping from parameter name to tensor, with all tensors on CPU.

    Raises:
        RuntimeError: if the checkpoint has an unexpected structure or
            contains non-tensor values.
    """
    # Security: torch.load is pickle-based, so loading an untrusted checkpoint
    # without weights_only=True can execute arbitrary code. weights_only=True
    # restricts unpickling to tensors/containers (and is the default since
    # PyTorch 2.6); it is safe here because only tensor values are accepted.
    obj = torch.load(path, map_location="cpu", weights_only=True)
    if isinstance(obj, dict) and isinstance(obj.get("state_dict"), dict):
        sd = obj["state_dict"]
    elif isinstance(obj, dict):
        sd = obj
    else:
        raise RuntimeError(f"Unexpected diff format: {type(obj)}")

    for k, v in sd.items():
        if not isinstance(v, torch.Tensor):
            raise RuntimeError(f"Diff contains non-tensor at key={k}: {type(v)}")
    return sd


def main(args: Args) -> None:
    """Merge a diff checkpoint into base weights: out = base + scale * diff.

    Only floating-point tensors receive the scaled diff; non-float tensors
    (and, in non-strict mode, base keys missing from the diff) are copied
    through unchanged. The result is saved to ``args.out_safetensors``.
    """
    logging.info("Loading base safetensors: %s", args.base_safetensors)
    base_sd = load_file(args.base_safetensors, device="cpu")  # dict[str, Tensor]

    logging.info("Loading diff pth: %s", args.diff_pth)
    diff_sd = load_diff_state_dict(args.diff_pth)

    keys_base = set(base_sd.keys())
    keys_diff = set(diff_sd.keys())

    if args.strict_keys:
        if keys_base != keys_diff:
            # Show at most 30 examples per side to keep the error readable.
            only_base = sorted(list(keys_base - keys_diff))[:30]
            only_diff = sorted(list(keys_diff - keys_base))[:30]
            raise RuntimeError(
                "Keys mismatch between base safetensors and diff.\n"
                f"Only in base (up to 30): {only_base}\n"
                f"Only in diff (up to 30): {only_diff}\n"
                "Use --no-strict-keys to apply on intersection only."
            )
        keys_apply = keys_base
    else:
        keys_apply = keys_base & keys_diff
        logging.warning("Non-strict mode: applying on intersection keys: %d", len(keys_apply))

    dev = torch.device(args.device)

    merged_sd: dict[str, torch.Tensor] = {}
    applied_float = 0      # float tensors that got the scaled diff added
    skipped_nonfloat = 0   # non-float tensors copied through unchanged
    skipped_missing = 0    # base keys absent from diff (non-strict mode only)

    for k, base_t_cpu in base_sd.items():
        base_t = base_t_cpu  # already on cpu

        if k not in keys_apply:
            merged_sd[k] = base_t
            skipped_missing += 1
            continue

        diff_t_cpu = diff_sd[k]

        if base_t.shape != diff_t_cpu.shape:
            raise RuntimeError(f"Shape mismatch at key={k}: base {base_t.shape} vs diff {diff_t_cpu.shape}")

        # only add for floating-point tensors
        if base_t.is_floating_point() and diff_t_cpu.is_floating_point():
            # Do the arithmetic in args.dtype on args.device, then cast the
            # result back to the base tensor's original dtype for saving.
            a = cast(base_t.to(dev), args.dtype)
            d = cast(diff_t_cpu.to(dev), args.dtype)
            out = a + args.scale * d
            merged_sd[k] = out.to(base_t.dtype).detach().cpu()
            applied_float += 1
        else:
            merged_sd[k] = base_t
            skipped_nonfloat += 1

    out_path = Path(args.out_safetensors)
    out_path.parent.mkdir(parents=True, exist_ok=True)

    # safetensors requires all tensors to be on CPU before saving
    for k, v in merged_sd.items():
        if v.device.type != "cpu":
            merged_sd[k] = v.cpu()

    logging.info(
        "Done. applied_float=%d, skipped_nonfloat=%d, skipped_missing=%d",
        applied_float,
        skipped_nonfloat,
        skipped_missing,
    )
    logging.info("Saving merged safetensors to: %s", str(out_path))
    save_file(merged_sd, str(out_path))


if __name__ == "__main__":
    # force=True replaces any handlers already installed by imported libraries
    logging.basicConfig(level=logging.INFO, force=True)
    main(tyro.cli(Args))