| """Helper script to convert diffusion checkpoints to format used by image generator.""" |
|
|
| import os |
|
|
| from absl import app |
| from absl import flags |
| import requests |
| import torch as th |
|
|
|
|
# Command-line flags. `ckpt_path` must be supplied by the caller; `output_path`
# falls back to a local "bins" folder when omitted.
_CKPT_PATH = flags.DEFINE_string(
    "ckpt_path",
    default=None,
    required=True,
    help="Path to checkpoint file",
)
_OUTPUT_PATH = flags.DEFINE_string(
    "output_path",
    default="bins",
    help="Output folder path",
)


# Vocabulary file that is downloaded next to the converted weights.
VOCAB_URL = "https://openaipublic.blob.core.windows.net/clip/bpe_simple_vocab_16e6.txt"
|
|
|
|
def _download_file(url, dest_path, timeout=60):
    """Downloads `url` to `dest_path` without ever leaving a partial file there.

    The body is streamed into a sibling ".part" file and only moved to
    `dest_path` once the download has fully succeeded, so an interrupted or
    failed download is never mistaken for a valid file on a later run.

    Args:
      url: Address of the file to fetch.
      dest_path: Final location of the downloaded file.
      timeout: Seconds to wait for the server before giving up.

    Raises:
      requests.HTTPError: If the server answers with an error status.
      requests.RequestException: On connection problems or timeouts.
    """
    tmp_path = dest_path + ".part"
    try:
        with requests.get(url, stream=True, timeout=timeout) as response:
            # Without this check an HTTP error page would be saved as the file.
            response.raise_for_status()
            with open(tmp_path, "wb") as out_file:
                for chunk in response.iter_content(chunk_size=8192):
                    out_file.write(chunk)
        os.replace(tmp_path, dest_path)
    finally:
        # After a successful os.replace the temp file no longer exists; on any
        # failure this removes the incomplete download.
        if os.path.exists(tmp_path):
            os.remove(tmp_path)


def run(ckpt_path, output_path):
    """Converts the checkpoint and saves the result.

    Creates `output_path` if needed, downloads the vocabulary file into it
    (unless it is already present) and writes every entry of the checkpoint's
    "state_dict" that exposes a `numpy()` method to "<key>.bin" as raw float16
    data. Entries whose key contains "first_stage_model.encoder" are skipped.

    Args:
      ckpt_path: Source checkpoint path
      output_path: Result folder directory

    Raises:
      requests.RequestException: If the vocabulary download fails.
      KeyError: If the checkpoint has no "state_dict" entry.
    """
    os.makedirs(output_path, exist_ok=True)
    # NOTE(security): th.load unpickles the file and can execute arbitrary code;
    # only run this script on checkpoints from a trusted source.
    ckpt = th.load(ckpt_path, map_location="cpu")

    vocab_dest = os.path.join(output_path, os.path.basename(VOCAB_URL))
    if not os.path.exists(vocab_dest):
        _download_file(VOCAB_URL, vocab_dest)

    for k, v in ckpt["state_dict"].items():
        if "first_stage_model.encoder" in k:
            continue
        # Skip entries that cannot be converted to a numpy array.
        if not hasattr(v, "numpy"):
            continue
        output_bin_file = os.path.join(output_path, f"{k}.bin")
        v.numpy().astype("float16").tofile(output_bin_file)
|
|
|
|
def main(_) -> None:
    """Entry point for absl: runs the conversion with the parsed flag values."""
    run(_CKPT_PATH.value, _OUTPUT_PATH.value)
|
|
# Let absl parse the command-line flags before handing control to main().
if __name__ == "__main__":
    app.run(main)
|
|