"""
Merge BAGEL EMA checkpoint into a standard inference checkpoint.

The repository ships two shards:

* ``ema.safetensors`` – EMA weights for the Mixture-of-Transformer stack,
  connector and ViT encoder described by ``llm_config.json`` / ``vit_config.json``.
* ``ae.safetensors`` – VAE weights referenced by ``model.safetensors.index.json``.

This script combines the two into a single ``model`` checkpoint that can be used
in place of the EMA file. By default the script keeps the source files untouched
and writes a new ``model_from_ema.safetensors`` plus, optionally, an accompanying
index.
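
Example invocation (the script filename is illustrative)::

    python merge_ema.py \
        --ema ema.safetensors \
        --ae ae.safetensors \
        --output model_from_ema.safetensors \
        --index model_from_ema.safetensors.index.json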
| """ |
|
|
| from __future__ import annotations |
|
|
| import argparse |
| import json |
| from collections import OrderedDict |
| from pathlib import Path |
| from typing import Dict |
|
|
| import torch |
|
|
|
|
| def parse_args() -> argparse.Namespace: |
| parser = argparse.ArgumentParser( |
| description="Convert BAGEL EMA weights into a regular inference checkpoint." |
| ) |
| parser.add_argument( |
| "--ema", |
| type=Path, |
| default=Path("ema.safetensors"), |
| help="Path to the EMA weights file (default: ema.safetensors).", |
| ) |
| parser.add_argument( |
| "--ae", |
| type=Path, |
| default=Path("ae.safetensors"), |
| help="Path to the VAE weights file (default: ae.safetensors).", |
| ) |
| parser.add_argument( |
| "--output", |
| type=Path, |
| default=Path("model_from_ema.safetensors"), |
| help="Destination for the merged checkpoint.", |
| ) |
| parser.add_argument( |
| "--index", |
| type=Path, |
| default=None, |
| help="Optional path for a Hugging Face style index JSON file.", |
| ) |
| return parser.parse_args() |


def load_safetensors(path: Path) -> Dict[str, torch.Tensor]:
    """Load every tensor from ``path``, importing safetensors lazily."""
    try:
        from safetensors.torch import load_file
    except ImportError as exc:
        raise RuntimeError(
            "safetensors is required. Install it with `pip install safetensors`."
        ) from exc

    tensors = load_file(str(path))
    if not tensors:
        raise ValueError(f"{path} does not contain any tensors.")
    return tensors


def save_safetensors(
    tensors: Dict[str, torch.Tensor], path: Path, *, metadata: Dict[str, str]
) -> None:
    """Write ``tensors`` to ``path`` with the given string-valued metadata."""
    try:
        from safetensors.torch import save_file
    except ImportError as exc:
        raise RuntimeError(
            "safetensors is required. Install it with `pip install safetensors`."
        ) from exc

    save_file(tensors, str(path), metadata=metadata)


def compute_total_size_bytes(tensors: Dict[str, torch.Tensor]) -> int:
    """Return the combined in-memory size of all tensors, in bytes."""
    total = 0
    for tensor in tensors.values():
        total += tensor.element_size() * tensor.nelement()
    return total


def main() -> None:
    """Merge the EMA and VAE shards and write the combined checkpoint."""
    args = parse_args()

    if not args.ema.is_file():
        raise FileNotFoundError(f"EMA weights not found: {args.ema}")
    if not args.ae.is_file():
        raise FileNotFoundError(f"VAE weights not found: {args.ae}")

    ema_state = load_safetensors(args.ema)
    ae_state = load_safetensors(args.ae)

    # The two shards should be disjoint; an overlapping name would be silently
    # shadowed during the merge, so refuse to continue.
    overlap = set(ae_state.keys()) & set(ema_state.keys())
    if overlap:
        raise ValueError(
            f"Found {len(overlap)} overlapping parameter names between ae and ema files; "
            "please inspect your checkpoints before merging."
        )

    # Deterministic key order: sorted VAE tensors first, then sorted EMA tensors.
    merged = OrderedDict()
    merged.update(sorted(ae_state.items()))
    merged.update(sorted(ema_state.items()))

    total_size = compute_total_size_bytes(merged)
    # safetensors metadata values must be strings; the "format" entry is what
    # Hugging Face loaders expect to find in a PyTorch checkpoint.
    metadata = {"format": "pt", "total_size": str(total_size)}
    save_safetensors(merged, args.output, metadata=metadata)
    print(f"Wrote {len(merged)} tensors ({total_size} bytes) to {args.output}")

    if args.index:
        # Map every tensor to the single output shard, mirroring the layout of
        # model.safetensors.index.json.
        weight_map = {key: args.output.name for key in merged.keys()}
        index_payload = {
            "metadata": {"total_size": total_size},
            "weight_map": weight_map,
        }
        args.index.write_text(
            json.dumps(index_payload, indent=4, ensure_ascii=False) + "\n",
            encoding="utf-8",
        )


if __name__ == "__main__":
    main()
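
# Quick round-trip check after running this script (a minimal sketch; the path
# assumes the default --output):
#
#     from safetensors.torch import load_file
#     merged = load_file("model_from_ema.safetensors")
#     print(f"{len(merged)} tensors in the merged checkpoint")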