"""Utility helpers for reproducible seeding, device resolution, and loading
the music dataset and MAGEL model checkpoints."""

import random

import datasets
import numpy as np
import torch
from datasets import DatasetDict
from transformers import AutoConfig

from dataset import MusicDataset
from modelling_qwen3 import MAGEL


def seed_everything(seed: int) -> None:
    """Seed Python, NumPy and PyTorch RNGs (including all CUDA devices)."""
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)


def resolve_device(device_arg: str) -> torch.device:
    """Resolve a device spec to a concrete ``torch.device``.

    Any explicit string (e.g. "cpu", "cuda:1") is passed through; "auto"
    prefers CUDA, then Apple MPS, then CPU.
    """
    if device_arg != "auto":
        return torch.device(device_arg)
    if torch.cuda.is_available():
        return torch.device("cuda")
    # Guard the attribute lookup: older PyTorch builds have no mps backend
    # and would raise AttributeError on direct access.
    mps_backend = getattr(torch.backends, "mps", None)
    if mps_backend is not None and mps_backend.is_available():
        return torch.device("mps")
    return torch.device("cpu")


def move_batch_to_device(
    batch: dict[str, torch.Tensor], device: torch.device
) -> dict[str, torch.Tensor]:
    """Return a copy of *batch* with every tensor moved to *device*.

    Non-tensor values are passed through unchanged.
    """
    return {
        key: value.to(device) if torch.is_tensor(value) else value
        for key, value in batch.items()
    }


def load_music_dataset(
    dataset_path: str,
    split: str,
    tokenizer_path: str,
    num_audio_token: int = 16384,
    fps: int = 25,
    use_fast: bool = True,
) -> MusicDataset:
    """Load a saved HF dataset from disk and wrap it as a ``MusicDataset``.

    Args:
        dataset_path: Path previously written by ``Dataset.save_to_disk``.
        split: Split name to select; validated when the on-disk object is a
            ``DatasetDict``.
        tokenizer_path: Tokenizer location forwarded to ``MusicDataset``.
        num_audio_token: Size of the audio token vocabulary.
        fps: Audio frames per second.
        use_fast: Whether to use the fast tokenizer implementation.

    Raises:
        KeyError: If *split* is absent from the on-disk ``DatasetDict``.
    """
    hf = datasets.load_from_disk(dataset_path)
    if isinstance(hf, DatasetDict):
        if split not in hf:
            raise KeyError(f"Split not found: {split}")
        container = hf
    else:
        # A bare Dataset has no splits; present it under the requested name
        # so MusicDataset sees a uniform mapping interface.
        container = {split: hf}
    return MusicDataset(
        datasets=container,
        split=split,
        tokenizer_path=tokenizer_path,
        num_audio_token=num_audio_token,
        fps=fps,
        use_fast=use_fast,
    )


def load_magel_checkpoint(
    checkpoint_path: str,
    device: torch.device,
    dtype: torch.dtype = torch.float32,
    attn_implementation: str = "sdpa",
) -> MAGEL:
    """Load a MAGEL checkpoint from local files, move it to *device*, set eval.

    ``local_files_only=True`` is passed so no network access is attempted.
    """
    config = AutoConfig.from_pretrained(
        checkpoint_path,
        local_files_only=True,
    )
    model = MAGEL.from_pretrained(
        checkpoint_path,
        config=config,
        torch_dtype=dtype,
        attn_implementation=attn_implementation,
        local_files_only=True,
    )
    model.to(device=device)
    model.eval()
    return model


def maybe_compile_model(
    model,
    enabled: bool = False,
    mode: str = "reduce-overhead",
):
    """Optionally wrap *model* with ``torch.compile``.

    Tags the returned object with ``_magel_is_compiled`` so downstream code
    (see ``maybe_mark_compile_step_begin``) knows whether to emit CUDA-graph
    step markers.

    Raises:
        RuntimeError: If compilation is requested but this PyTorch build has
            no ``torch.compile``.
    """
    if not enabled:
        setattr(model, "_magel_is_compiled", False)
        return model
    if not hasattr(torch, "compile"):
        raise RuntimeError("torch.compile is not available in this PyTorch build.")
    compiled_model = torch.compile(model, mode=mode)
    setattr(compiled_model, "_magel_is_compiled", True)
    return compiled_model


def maybe_mark_compile_step_begin(model) -> None:
    """Call ``torch.compiler.cudagraph_mark_step_begin`` for compiled models.

    No-op when the model was not compiled via ``maybe_compile_model`` or when
    the running PyTorch lacks the compiler namespace / marker function —
    every lookup is guarded so this is safe across PyTorch versions.
    """
    if not getattr(model, "_magel_is_compiled", False):
        return
    compiler_ns = getattr(torch, "compiler", None)
    if compiler_ns is None:
        return
    mark_step_begin = getattr(compiler_ns, "cudagraph_mark_step_begin", None)
    if mark_step_begin is None:
        return
    mark_step_begin()