ACE-Music-Studio / lyrics_lm.py
techfreakworm's picture
fix(deploy): symlink hf preloads into fork's checkpoint dir + strip prompt tokens in lyrics
03ddd85 unverified
"""Qwen 2.5 7B Instruct as the lyrics writer.
Mac path: ``mlx-lm`` with the 4-bit MLX quantisation (``mlx-community/
Qwen2.5-7B-Instruct-4bit``) for speed and a low VRAM footprint on Apple
Silicon. The 4-bit pack is ~4 GB on disk and runs in ~8-12 s per draft
on an M5 Max after the first warm-up.
CUDA / CPU path: ``transformers`` with the full ``Qwen/Qwen2.5-7B-Instruct``
checkpoint, ``apply_chat_template`` for the prompt, and ``do_sample=True``
generation.
Loading is lazy — the module-level ``_LM`` singleton is constructed on the
first call to ``_get_lm()`` so module import stays fast for CI and so the
~4 GB MLX download is only triggered when the user actually clicks
"▶ Draft lyrics" in the Lyrics tab.
Tests in ``tests/test_lyrics_lm.py`` mock ``_get_lm`` at the module
boundary so the real Qwen weights are never loaded in CI.
"""
from __future__ import annotations
import re
from dataclasses import dataclass
from typing import Any
import ace_pipeline as ap
_DEFAULT_MAC_ID = "mlx-community/Qwen2.5-7B-Instruct-4bit"
_DEFAULT_CUDA_ID = "Qwen/Qwen2.5-7B-Instruct"
_LM = None # lazy module-level singleton
def build_system_prompt() -> str:
"""Locked songwriter system prompt for the Lyrics tab.
Returns a single multi-line string that instructs Qwen to emit ONLY
structurally-tagged lyrics (``[intro]`` ``[verse 1]`` ``[chorus]``
etc.). The exact tag vocabulary is what ACE-Step's 5Hz LM planner
expects downstream when the user pipes the draft into the Generate
tab via the "Use these in Generate" button.
"""
return (
"You are a songwriter. Output ONLY structured lyrics for an AI music generator.\n"
"Use these section tags exactly: [intro] [verse 1] [verse 2] [chorus] [bridge] [outro] (etc.)\n"
"Each section is on its own line, followed by the lyrics for that section. "
"Keep verses 4-8 lines, choruses 4 lines, bridges 2-4 lines. "
"Match the requested tone and language. "
"Do not include commentary, headers, or markdown."
)
def _build_user_prompt(
brief: str,
structure: str,
language: str,
tone: str,
verse_lines: int,
chorus_lines: int,
bridge_lines: int,
rhyme: str,
) -> str:
return (
f"Write lyrics with this structure: {structure}.\n"
f"Language: {language}. Tone: {tone or 'neutral'}. Rhyme: {rhyme}.\n"
f"Verse: {verse_lines} lines. Chorus: {chorus_lines} lines. Bridge: {bridge_lines} lines.\n\n"
f"Brief:\n{brief}\n"
)
def _normalise(text: str) -> str:
"""Lowercase section tags and strip outer whitespace.
Qwen occasionally emits ``[Verse 1]`` or ``[CHORUS]`` despite the
system prompt asking for lowercase tags. ACE-Step's 5Hz LM expects
lowercase, so we coerce here rather than relying on every downstream
consumer to lowercase before parsing.
"""
def lower_tag(match: re.Match[str]) -> str:
return "[" + match.group(1).lower() + "]"
return re.sub(r"\[([^\]]+)\]", lower_tag, text).strip()
def _get_lm():
"""Return the lazy module-level LM singleton.
Tests in ``tests/test_lyrics_lm.py`` monkeypatch this function so
``_load_lm()`` is never invoked under pytest. In production the
first call constructs the singleton once and caches it for the
process lifetime.
"""
global _LM
if _LM is None:
_LM = _load_lm()
return _LM
def _load_lm():
"""Construct the per-device LM wrapper.
On MPS we use ``mlx-lm`` which expects a model ID and returns
``(model, tokenizer)``. On CUDA / CPU we use ``transformers`` with
``apply_chat_template`` for the prompt.
"""
device = ap.detect_device()
if device == "mps":
from mlx_lm import load # type: ignore[import-not-found]
model, tokenizer = load(_DEFAULT_MAC_ID)
return _MLXLM(model=model, tokenizer=tokenizer)
# CUDA / CPU fallback path. Use bfloat16 on CUDA, float32 on CPU.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
tok = AutoTokenizer.from_pretrained(_DEFAULT_CUDA_ID)
dtype = torch.bfloat16 if device == "cuda" else torch.float32
model = AutoModelForCausalLM.from_pretrained(_DEFAULT_CUDA_ID, torch_dtype=dtype)
if device == "cuda":
model = model.to("cuda")
return _HFLM(model=model, tokenizer=tok)
@dataclass
class _MLXLM:
"""mlx-lm wrapper. ``generate`` returns a plain string (post-decode)."""
model: Any
tokenizer: Any
def generate(self, system: str, user: str, **kw: Any) -> str:
import mlx.core as mx # type: ignore[import-not-found]
import mlx_lm.generate as mlx_gen_mod # type: ignore[import-not-found]
from mlx_lm import generate # type: ignore[import-not-found]
# Qwen's ChatML template — mlx-lm doesn't expose apply_chat_template
# the way HF does, so build the prompt manually here.
prompt = (
f"<|im_start|>system\n{system}<|im_end|>\n"
f"<|im_start|>user\n{user}<|im_end|>\n"
f"<|im_start|>assistant\n"
)
# Gradio runs handlers in anyio worker threads. MLX maintains a
# *per-thread* default stream and a module-level ``generation_stream``
# that was created at mlx_lm import time on the MAIN thread. Both
# need to be valid in the *current* (worker) thread or
# ``wired_limit().__exit__`` crashes with "There is no Stream(gpu, 0)
# in current thread" when it calls ``mx.synchronize(generation_stream)``.
#
# Two-part fix:
# 1. ``mx.stream(mx.gpu)`` wrap installs the default GPU stream
# for the current thread for the duration of the call.
# 2. Re-assign ``mlx_lm.generate.generation_stream`` to a stream
# created in the *current* thread so ``mx.synchronize`` doesn't
# reach across thread boundaries. The reassignment is safe
# because Gradio's queue runs at default_concurrency_limit=1 —
# no two lyrics drafts run concurrently.
with mx.stream(mx.gpu):
mlx_gen_mod.generation_stream = mx.new_stream(mx.default_device())
return generate(
self.model,
self.tokenizer,
prompt=prompt,
max_tokens=int(kw.get("max_new_tokens", 600)),
)
@dataclass
class _HFLM:
"""transformers wrapper. ``generate`` returns the assistant continuation."""
model: Any
tokenizer: Any
def generate(self, system: str, user: str, **kw: Any) -> str:
msgs = [
{"role": "system", "content": system},
{"role": "user", "content": user},
]
prompt = self.tokenizer.apply_chat_template(
msgs,
tokenize=False,
add_generation_prompt=True,
)
inputs = self.tokenizer(prompt, return_tensors="pt").to(self.model.device)
out = self.model.generate(
**inputs,
max_new_tokens=int(kw.get("max_new_tokens", 600)),
temperature=float(kw.get("temperature", 0.85)),
top_p=float(kw.get("top_p", 0.9)),
top_k=int(kw.get("top_k", 40)),
repetition_penalty=float(kw.get("repetition_penalty", 1.1)),
do_sample=True,
)
# Slice off the prompt tokens at the *token* level. Doing it at the
# string level (full.startswith(prompt)) is brittle because
# ``skip_special_tokens=True`` strips the ChatML markers from
# ``full`` but they're still present in ``prompt`` — so the prefix
# match fails and the system + user turns leak into the output.
prompt_len = int(inputs["input_ids"].shape[1])
generated_ids = out[0][prompt_len:]
return self.tokenizer.decode(generated_ids, skip_special_tokens=True)
def generate_lyrics(
brief: str,
structure: str,
language: str,
tone: str,
verse_lines: int,
chorus_lines: int,
bridge_lines: int,
rhyme: str,
temperature: float,
top_p: float,
top_k: int,
max_new_tokens: int,
seed: int | None = None,
) -> str:
"""Draft structurally-tagged lyrics for the Lyrics tab.
Builds the user prompt from the form fields, asks the LM to generate,
and runs the output through ``_normalise()`` so section tags are
lowercase. ``seed`` is accepted for parity with the UI but is not
threaded through the mlx-lm / transformers ``generate`` calls because
neither backend's high-level ``generate(...)`` helper accepts a seed
in the version we ship with — deterministic seeding would require
dropping to the per-step token loop, which we'll add if reproducibility
becomes a hard requirement.
"""
lm = _get_lm()
user = _build_user_prompt(
brief, structure, language, tone, verse_lines, chorus_lines, bridge_lines, rhyme
)
raw = lm.generate(
system=build_system_prompt(),
user=user,
temperature=temperature,
top_p=top_p,
top_k=top_k,
max_new_tokens=max_new_tokens,
)
return _normalise(raw)