# cond_gen/runtime_utils.py
# Runtime utilities for conditional music generation: RNG seeding, device
# resolution, dataset/model loading, and optional torch.compile helpers.
import random
import datasets
import numpy as np
import torch
from datasets import DatasetDict
from transformers import AutoConfig
from dataset import MusicDataset
from modelling_qwen3 import MAGEL
def seed_everything(seed: int) -> None:
    """Seed every RNG source (Python, NumPy, Torch CPU) with *seed*.

    CUDA generators are seeded on all visible devices when CUDA is present.
    """
    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)
def resolve_device(device_arg: str) -> torch.device:
    """Resolve a device argument string to a concrete ``torch.device``.

    ``"auto"`` selects CUDA if available, then Apple MPS, then CPU; any
    other value (e.g. ``"cpu"``, ``"cuda:1"``) is passed straight to
    ``torch.device``.
    """
    if device_arg != "auto":
        return torch.device(device_arg)
    if torch.cuda.is_available():
        return torch.device("cuda")
    # Guard: older / minimal torch builds may not expose torch.backends.mps,
    # in which case the original attribute access would raise AttributeError.
    mps_backend = getattr(torch.backends, "mps", None)
    if mps_backend is not None and mps_backend.is_available():
        return torch.device("mps")
    return torch.device("cpu")
def move_batch_to_device(
    batch: dict[str, torch.Tensor], device: torch.device
) -> dict[str, torch.Tensor]:
    """Return a shallow copy of *batch* with every tensor moved to *device*.

    Non-tensor values (ints, strings, lists, ...) are carried over untouched.
    """
    moved: dict[str, torch.Tensor] = {}
    for name, item in batch.items():
        moved[name] = item.to(device) if torch.is_tensor(item) else item
    return moved
def load_music_dataset(
    dataset_path: str,
    split: str,
    tokenizer_path: str,
    num_audio_token: int = 16384,
    fps: int = 25,
    use_fast: bool = True,
) -> MusicDataset:
    """Load an HF dataset saved on disk and wrap it in a ``MusicDataset``.

    A bare ``Dataset`` (no split structure) is placed into a one-entry
    mapping under *split*; a ``DatasetDict`` must already contain *split*,
    otherwise ``KeyError`` is raised.
    """
    loaded = datasets.load_from_disk(dataset_path)
    if not isinstance(loaded, DatasetDict):
        container = {split: loaded}
    elif split in loaded:
        container = loaded
    else:
        raise KeyError(f"Split not found: {split}")
    return MusicDataset(
        datasets=container,
        split=split,
        tokenizer_path=tokenizer_path,
        num_audio_token=num_audio_token,
        fps=fps,
        use_fast=use_fast,
    )
def load_magel_checkpoint(
    checkpoint_path: str,
    device: torch.device,
    dtype: torch.dtype = torch.float32,
    attn_implementation: str = "sdpa",
) -> MAGEL:
    """Load a MAGEL model from a local checkpoint, ready for inference.

    Reads the config and weights strictly from local files (no hub access),
    casts parameters to *dtype*, moves the model to *device*, and switches
    it to eval mode before returning.
    """
    cfg = AutoConfig.from_pretrained(
        checkpoint_path,
        local_files_only=True,
    )
    magel = MAGEL.from_pretrained(
        checkpoint_path,
        config=cfg,
        torch_dtype=dtype,
        attn_implementation=attn_implementation,
        local_files_only=True,
    )
    magel.to(device=device)
    magel.eval()
    return magel
def maybe_compile_model(
    model,
    enabled: bool = False,
    mode: str = "reduce-overhead",
):
    """Optionally wrap *model* in ``torch.compile``.

    The returned object carries a ``_magel_is_compiled`` flag so downstream
    helpers know whether CUDA-graph step marking is needed. When *enabled*
    is False the original model is returned unmodified (flag set False).
    Raises ``RuntimeError`` if compilation is requested but this PyTorch
    build has no ``torch.compile``.
    """
    if not enabled:
        model._magel_is_compiled = False
        return model
    compile_fn = getattr(torch, "compile", None)
    if compile_fn is None:
        raise RuntimeError("torch.compile is not available in this PyTorch build.")
    compiled = compile_fn(model, mode=mode)
    compiled._magel_is_compiled = True
    return compiled
def maybe_mark_compile_step_begin(model) -> None:
    """Signal a new CUDA-graph step for a ``torch.compile``-d model.

    No-op unless the model carries a truthy ``_magel_is_compiled`` flag and
    the running torch build exposes
    ``torch.compiler.cudagraph_mark_step_begin``.
    """
    if not getattr(model, "_magel_is_compiled", False):
        return
    hook = getattr(
        getattr(torch, "compiler", None), "cudagraph_mark_step_begin", None
    )
    if hook is not None:
        hook()