| import random |
|
|
| import datasets |
| import numpy as np |
| import torch |
| from datasets import DatasetDict |
| from transformers import AutoConfig |
|
|
| from dataset import MusicDataset |
| from modelling_qwen3 import MAGEL |
|
|
|
|
def seed_everything(seed: int) -> None:
    """Seed every RNG in play (Python, NumPy, PyTorch) for reproducibility."""
    for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
        seed_fn(seed)
    # CUDA keeps its own per-device generators; seed them all when present.
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)
|
|
|
|
def resolve_device(device_arg: str) -> torch.device:
    """Turn a device string into a ``torch.device``.

    Any explicit value ("cpu", "cuda:1", ...) is passed straight to
    ``torch.device``; the sentinel "auto" prefers CUDA, then Apple MPS,
    and finally falls back to CPU.
    """
    if device_arg != "auto":
        return torch.device(device_arg)
    backend = "cpu"
    if torch.cuda.is_available():
        backend = "cuda"
    elif torch.backends.mps.is_available():
        backend = "mps"
    return torch.device(backend)
|
|
|
|
def move_batch_to_device(
    batch: dict[str, torch.Tensor],
    device: torch.device,
    non_blocking: bool = False,
) -> dict[str, torch.Tensor]:
    """Return a copy of *batch* with every tensor value moved to *device*.

    Non-tensor values (e.g. lists of strings, ints) are carried over
    unchanged. *non_blocking* is forwarded to ``Tensor.to`` so that
    pinned-memory host tensors can be copied to a CUDA device
    asynchronously; the default ``False`` preserves the original
    synchronous behavior.
    """
    return {
        key: value.to(device, non_blocking=non_blocking)
        if torch.is_tensor(value)
        else value
        for key, value in batch.items()
    }
|
|
def load_music_dataset(
    dataset_path: str,
    split: str,
    tokenizer_path: str,
    num_audio_token: int = 16384,
    fps: int = 25,
    use_fast: bool = True,
) -> MusicDataset:
    """Load a HF dataset saved on disk and wrap it in a ``MusicDataset``.

    A stored ``DatasetDict`` is used as-is (and must contain *split*);
    a single dataset is wrapped into a ``{split: dataset}`` mapping.

    Raises:
        KeyError: if the on-disk ``DatasetDict`` has no entry for *split*.
    """
    stored = datasets.load_from_disk(dataset_path)
    if not isinstance(stored, DatasetDict):
        stored = {split: stored}
    elif split not in stored:
        raise KeyError(f"Split not found: {split}")
    return MusicDataset(
        datasets=stored,
        split=split,
        tokenizer_path=tokenizer_path,
        num_audio_token=num_audio_token,
        fps=fps,
        use_fast=use_fast,
    )
|
|
|
|
def load_magel_checkpoint(
    checkpoint_path: str,
    device: torch.device,
    dtype: torch.dtype = torch.float32,
    attn_implementation: str = "sdpa",
) -> MAGEL:
    """Load a MAGEL checkpoint from local files and prepare it for inference.

    The model config is read from the checkpoint directory, the weights are
    instantiated with the requested dtype and attention implementation, and
    the model is moved to *device* and switched to eval mode before being
    returned. Only local files are consulted (no hub downloads).
    """
    model_config = AutoConfig.from_pretrained(checkpoint_path, local_files_only=True)
    model = MAGEL.from_pretrained(
        checkpoint_path,
        config=model_config,
        torch_dtype=dtype,
        attn_implementation=attn_implementation,
        local_files_only=True,
    )
    model.to(device=device)
    model.eval()
    return model
|
|
|
|
def maybe_compile_model(
    model,
    enabled: bool = False,
    mode: str = "reduce-overhead",
):
    """Optionally wrap *model* with ``torch.compile``.

    The returned object is always tagged with ``_magel_is_compiled`` so
    callers can tell whether compilation actually happened.

    Raises:
        RuntimeError: if compilation was requested but this PyTorch build
            does not provide ``torch.compile``.
    """
    if not enabled:
        model._magel_is_compiled = False
        return model
    if not hasattr(torch, "compile"):
        raise RuntimeError("torch.compile is not available in this PyTorch build.")
    compiled = torch.compile(model, mode=mode)
    compiled._magel_is_compiled = True
    return compiled
|
|
|
|
def maybe_mark_compile_step_begin(model) -> None:
    """Signal a new CUDA-graph step before invoking a compiled model.

    No-op unless *model* was tagged as compiled by ``maybe_compile_model``
    and this torch build exposes
    ``torch.compiler.cudagraph_mark_step_begin``.
    """
    if not getattr(model, "_magel_is_compiled", False):
        return
    # Resolve the hook defensively: older torch builds may lack either
    # the ``torch.compiler`` namespace or the function itself.
    compiler_ns = getattr(torch, "compiler", None)
    mark_fn = getattr(compiler_ns, "cudagraph_mark_step_begin", None)
    if mark_fn is not None:
        mark_fn()
|
|