#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Train MAGEL directly from a vanilla Qwen3 checkpoint.

Compared with train.py/train_newparaonly.py, this script:
1) Loads an original Qwen3 base checkpoint.
2) Resolves MAGEL hparams explicitly at construction time.
3) Initializes MAGEL extra modules from scratch and trains end-to-end.
"""
import argparse
import os

import torch
from transformers import (
    AutoConfig,
    Trainer,
    TrainingArguments,
)

import datasets

from dataset import DataCollate, MusicDataset
from modelling_qwen3 import MAGEL


def resolve_model_source(model_path: str, resume_from_checkpoint: str | None) -> str:
    """Return the directory that config/model should be loaded from.

    When resuming, the checkpoint directory wins over ``--model_path`` so the
    restored weights stay consistent with the run being resumed; a warning is
    printed if the two paths disagree.

    Args:
        model_path: Local base checkpoint path from ``--model_path``.
        resume_from_checkpoint: Trainer checkpoint directory, or None/empty
            when starting fresh.

    Returns:
        ``model_path`` when not resuming, otherwise ``resume_from_checkpoint``.
    """
    if not resume_from_checkpoint:
        return model_path
    if os.path.abspath(model_path) != os.path.abspath(resume_from_checkpoint):
        print(
            "Ignoring --model_path during resume and loading config/model from: "
            f"{resume_from_checkpoint}"
        )
    return resume_from_checkpoint


def create_model(
    model_path: str,
    model_dtype: torch.dtype,
    target_vocab_size: int,
    attn_implementation: str,
) -> MAGEL:
    """Build a MAGEL model from a local Qwen3/MAGEL checkpoint.

    Args:
        model_path: Local checkpoint directory (base model or resume source).
        model_dtype: Parameter dtype to load with (e.g. ``torch.bfloat16``).
        target_vocab_size: Final vocabulary size; token embeddings are resized
            to this value to accommodate the added audio tokens.
        attn_implementation: Attention backend name ("eager", "sdpa", or
            "flash_attention_2").

    Returns:
        The constructed MAGEL model, with parameter counts logged to stdout.
    """
    print(f"Loading Qwen3 model from: {model_path}")
    config = AutoConfig.from_pretrained(
        model_path,
        local_files_only=True,
    )
    # ignore_mismatched_sizes lets freshly added MAGEL modules initialize from
    # scratch when the base checkpoint has no matching weights for them.
    model = MAGEL.from_pretrained(
        model_path,
        torch_dtype=model_dtype,
        config=config,
        attn_implementation=attn_implementation,
        ignore_mismatched_sizes=True,
        local_files_only=True,
    )
    model.resize_token_embeddings(target_vocab_size)

    total_params = sum(p.numel() for p in model.parameters())
    trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    # Extra parameters introduced by MAGEL on top of the Qwen3 backbone,
    # identified by the module-name substrings used in modelling_qwen3.
    magel_extra_params = sum(
        p.numel()
        for name, p in model.named_parameters()
        if ("condition_encoder" in name or "dit_adaln" in name)
    )
    print(f"Total parameters: {total_params:,}")
    print(f"Trainable parameters: {trainable_params:,}")
    print(f"MAGEL extra parameters: {magel_extra_params:,}")
    print(
        "MAGEL config: "
        f"adaln_dim={model.adaln_dim}, "
        f"chord_dropout_trigger_prob={model.chord_dropout_trigger_prob}, "
        f"structure_dropout_trigger_prob={model.structure_dropout_trigger_prob}"
    )
    return model


def create_dataset(
    dataset_path: str,
    tokenizer_path: str,
    num_audio_token: int = 16384,
) -> MusicDataset:
    """Load the on-disk HF dataset and wrap it in a ``MusicDataset``.

    Args:
        dataset_path: Directory previously written by ``datasets.save_to_disk``.
        tokenizer_path: Local tokenizer checkpoint path.
        num_audio_token: Number of discrete audio tokens added to the
            tokenizer vocabulary.

    Returns:
        The training-split ``MusicDataset``.
    """
    print(f"Loading dataset from: {dataset_path}")
    print(f"Loading tokenizer from: {tokenizer_path}")
    hf_ds = datasets.load_from_disk(dataset_path)
    train_dataset = MusicDataset(
        hf_ds,
        split="train",
        tokenizer_path=tokenizer_path,
        num_audio_token=num_audio_token,
        use_fast=True,
    )
    print(f"Dataset size: {len(train_dataset)}")
    return train_dataset


def main():
    """Parse CLI args, build dataset/model, and run HF Trainer end-to-end."""
    parser = argparse.ArgumentParser(
        description="Train MAGEL directly from a vanilla Qwen3 base checkpoint."
    )
    parser.add_argument(
        "--dataset_path",
        type=str,
        default="muse_mucodec_chord.ds",
    )
    parser.add_argument(
        "--model_path",
        type=str,
        default="checkpoints/Qwen3-0.6B",
        help="Local Qwen3 base checkpoint path.",
    )
    parser.add_argument(
        "--tokenizer_path",
        type=str,
        default="checkpoints/Qwen3-0.6B",
        help="Local tokenizer checkpoint path.",
    )
    parser.add_argument(
        "--model_dtype",
        type=str,
        default="bfloat16",
        choices=["float32", "float16", "bfloat16"],
    )
    parser.add_argument(
        "--attn_implementation",
        type=str,
        default="sdpa",
        choices=["eager", "sdpa", "flash_attention_2"],
    )
    parser.add_argument("--output_dir", type=str, default="./output_qwen3_0p6b_train")
    parser.add_argument("--per_device_train_batch_size", type=int, default=1)
    parser.add_argument("--gradient_accumulation_steps", type=int, default=4)
    parser.add_argument("--learning_rate", type=float, default=1e-4)
    parser.add_argument("--weight_decay", type=float, default=0.01)
    parser.add_argument("--num_train_epochs", type=float, default=20)
    parser.add_argument("--warmup_steps", type=int, default=1000)
    parser.add_argument("--max_grad_norm", type=float, default=5.0)
    parser.add_argument("--logging_steps", type=int, default=10)
    parser.add_argument(
        "--resume_from_checkpoint",
        type=str,
        default=None,
        help="Resume training from a Trainer checkpoint directory such as output_dir/checkpoint-500.",
    )
    parser.add_argument("--dataloader_num_workers", type=int, default=12)
    parser.add_argument(
        "--gradient_checkpointing",
        dest="gradient_checkpointing",
        action="store_true",
    )
    parser.add_argument(
        "--deepspeed",
        type=str,
        default=None,
        help="Path to DeepSpeed config. Leave unset to disable DeepSpeed.",
    )
    parser.add_argument("--report_to", type=str, default="wandb")
    parser.add_argument("--wandb_project", type=str, default="vaultum-qwen3-0p6b")
    parser.add_argument("--wandb_run_name", type=str, default=None)
    args = parser.parse_args()

    model_dtype = {
        "float32": torch.float32,
        "float16": torch.float16,
        "bfloat16": torch.bfloat16,
    }[args.model_dtype]

    model_source = resolve_model_source(
        model_path=args.model_path,
        resume_from_checkpoint=args.resume_from_checkpoint,
    )

    base_config = AutoConfig.from_pretrained(
        model_source,
        local_files_only=True,
    )
    # NOTE(review): a truly vanilla Qwen3 config may not carry
    # magel_num_audio_token — confirm the checkpoint's config.json includes it,
    # otherwise this raises AttributeError before training starts.
    num_audio_token = int(base_config.magel_num_audio_token)
    print(f"Using num_audio_token={num_audio_token}")

    train_dataset = create_dataset(
        dataset_path=args.dataset_path,
        tokenizer_path=args.tokenizer_path,
        num_audio_token=num_audio_token,
    )
    target_vocab_size = train_dataset.tokenizer_vocab_size

    model = create_model(
        model_path=model_source,
        model_dtype=model_dtype,
        attn_implementation=args.attn_implementation,
        target_vocab_size=target_vocab_size,
    )

    # Export the project name before any W&B integration can initialize
    # (previously this was set only after TrainingArguments construction).
    if args.wandb_project and "wandb" in args.report_to:
        os.environ["WANDB_PROJECT"] = args.wandb_project

    training_args = TrainingArguments(
        output_dir=args.output_dir,
        per_device_train_batch_size=args.per_device_train_batch_size,
        gradient_accumulation_steps=args.gradient_accumulation_steps,
        learning_rate=args.learning_rate,
        weight_decay=args.weight_decay,
        num_train_epochs=args.num_train_epochs,
        warmup_steps=args.warmup_steps,
        max_grad_norm=args.max_grad_norm,
        logging_steps=args.logging_steps,
        save_strategy="epoch",
        dataloader_num_workers=args.dataloader_num_workers,
        bf16=(args.model_dtype == "bfloat16"),
        fp16=(args.model_dtype == "float16"),
        gradient_checkpointing=args.gradient_checkpointing,
        # Non-reentrant checkpointing is required for compatibility with
        # newer PyTorch autograd; only consulted when checkpointing is on.
        gradient_checkpointing_kwargs={"use_reentrant": False},
        deepspeed=args.deepspeed,
        remove_unused_columns=False,
        dataloader_drop_last=True,
        report_to=args.report_to,
        logging_dir=None,
        run_name=args.wandb_run_name,
    )

    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset,
        data_collator=DataCollate(),
    )

    if args.resume_from_checkpoint:
        print(f"Resuming training from checkpoint: {args.resume_from_checkpoint}")
    else:
        print("Starting training from current model initialization.")

    trainer.train(resume_from_checkpoint=args.resume_from_checkpoint)

    # Persist the final model and tokenizer together so the directory is
    # directly loadable with from_pretrained.
    final_dir = os.path.join(args.output_dir, "final")
    trainer.save_model(final_dir)
    train_dataset.tokenizer.save_pretrained(final_dir)
    print(f"Training complete. Final model saved to: {final_dir}")


if __name__ == "__main__":
    main()