Spaces:
Running on Zero
Running on Zero
File size: 2,848 Bytes
6215e7d | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 | import json
import torch
from safetensors import safe_open
from safetensors.torch import load_file
from stable_audio_3.factory import (
create_autoencoder_from_config,
create_diffusion_cond_from_config,
)
def copy_state_dict(model, state_dict):
    """Copy the compatible entries of ``state_dict`` into ``model``.

    The incoming keys are first passed through ``remap_state_dict_keys`` so
    they can be matched against the model's own key names. An entry is used
    only when its key exists in the model's state dict and both tensors have
    the same shape; every other entry is reported on stdout and skipped.

    Args:
        model: Object exposing ``state_dict()`` and ``load_state_dict()``.
        state_dict: Mapping of parameter name to tensor from a checkpoint.
    """
    target = model.state_dict()
    incoming = remap_state_dict_keys(state_dict, target)

    for key, tensor in incoming.items():
        usable = key in target and tensor.shape == target[key].shape
        if not usable:
            print(
                f"Key {key} not found in target state_dict or shape mismatch. Skipping."
            )
            continue
        target[key] = tensor

    # strict=False: the load is issued as a non-strict one even though the
    # dict passed in was derived from the model's own state dict.
    model.load_state_dict(target, strict=False)
def load_autoencoder(config_path: str, ckpt_path: str, device: str = "cpu"):
    """Build the autoencoder described by a config file and load its weights.

    Two checkpoint layouts are accepted:

    * combined DiT+autoencoder checkpoints, where the autoencoder tensors are
      stored under ``pretransform.model.*`` -- only those tensors are read,
      and the prefix is stripped from their keys;
    * standalone AE-only checkpoints (e.g. stabilityai/SAME-L / SAME-S), whose
      keys carry no prefix -- every tensor is read.

    Tensors are requested from ``safe_open`` directly on ``device``.

    Args:
        config_path: Path to a JSON config providing the ``"model"`` and
            ``"sample_rate"`` entries.
        ckpt_path: Path to the safetensors checkpoint.
        device: Device handed to ``safe_open`` and to the final ``.to`` call.

    Returns:
        The result of ``autoencoder.to(device)`` after the weights were copied.
    """
    # JSON is UTF-8 by specification; do not depend on the platform's
    # default text encoding when reading the config.
    with open(config_path, encoding="utf-8") as f:
        config = json.load(f)
    autoencoder = create_autoencoder_from_config(config["model"], config["sample_rate"])

    # Full DiT checkpoints nest the AE under pretransform.model.*;
    # standalone AE-only checkpoints have no prefix.
    nested_prefix = "pretransform.model."

    # A single handle serves both the key listing and the tensor reads, so
    # the checkpoint is opened only once.
    with safe_open(ckpt_path, framework="pt", device=device) as f:
        all_keys = list(f.keys())
        if any(k.startswith(nested_prefix) for k in all_keys):
            effective_prefix = nested_prefix
        else:
            effective_prefix = ""  # standalone AE -- keys are already bare
        state_dict = {
            k[len(effective_prefix):]: f.get_tensor(k)
            for k in all_keys
            if k.startswith(effective_prefix)
        }

    copy_state_dict(autoencoder, state_dict)
    return autoencoder.to(device)
def load_diffusion_cond(
    model_config,
    ckpt_path: str,
    device: str = "cuda",
    model_half: bool = False,
):
    """Create the conditioned diffusion model and load a checkpoint into it.

    Args:
        model_config: Configuration passed to
            ``create_diffusion_cond_from_config``.
        ckpt_path: Path to a safetensors checkpoint, read with ``load_file``.
        device: Device the model is moved to after the weights are copied.
        model_half: When true, the model is additionally converted to
            ``torch.float16``.

    Returns:
        The model, moved to ``device``, switched to eval mode and with
        gradients disabled.
    """
    model = create_diffusion_cond_from_config(model_config)
    checkpoint_weights = load_file(ckpt_path)
    copy_state_dict(model, checkpoint_weights)

    # NOTE(review): the calls below are issued one by one and rely on the
    # nn.Module convention that .to() / .eval() operate on the module itself
    # and return it -- confirm the project model does not override that.
    model.to(device)
    model.eval()
    model.requires_grad_(False)

    if not model_half:
        return model

    model.to(torch.float16)
    return model
def remap_state_dict_keys(state_dict, model_state_dict):
    """Rename checkpoint keys that differ from model keys by one extra segment.

    A key already present in ``model_state_dict`` is kept as is. Otherwise the
    dotted key is probed with one segment removed at a time, starting at the
    second segment (the first segment is never removed); the first probe found
    in ``model_state_dict`` becomes the new key. Keys with no matching probe
    are kept unchanged.

    Args:
        state_dict: Mapping of checkpoint key to value.
        model_state_dict: Mapping whose keys are the names to match against.

    Returns:
        A new dict with the (possibly renamed) keys in the original iteration
        order. If two entries end up with the same key, the later value wins.
    """

    def resolve(name):
        if name in model_state_dict:
            return name
        segments = name.split(".")
        for drop_at in range(1, len(segments)):
            head = ".".join(segments[:drop_at])
            tail = ".".join(segments[drop_at + 1:])
            # NOTE: at the last index ``tail`` is empty, so this probe ends
            # with a trailing dot; that behaviour is kept as it was.
            shortened = f"{head}.{tail}"
            if shortened in model_state_dict:
                return shortened
        return name

    return {resolve(name): value for name, value in state_dict.items()}
|