"""Qwen 2.5 7B Instruct as the lyrics writer.
Mac path: ``mlx-lm`` with the 4-bit MLX quantisation (``mlx-community/
Qwen2.5-7B-Instruct-4bit``) for speed and a low VRAM footprint on Apple
Silicon. The 4-bit pack is ~4 GB on disk and runs in ~8-12 s per draft
on an M5 Max after the first warm-up.
CUDA / CPU path: ``transformers`` with the full ``Qwen/Qwen2.5-7B-Instruct``
checkpoint, ``apply_chat_template`` for the prompt, and ``do_sample=True``
generation.
Loading is lazy — the module-level ``_LM`` singleton is constructed on the
first call to ``_get_lm()`` so module import stays fast for CI and so the
~4 GB MLX download is only triggered when the user actually clicks
"▶ Draft lyrics" in the Lyrics tab.
Tests in ``tests/test_lyrics_lm.py`` mock ``_get_lm`` at the module
boundary so the real Qwen weights are never loaded in CI.
"""
from __future__ import annotations
import re
from dataclasses import dataclass
from typing import Any
import ace_pipeline as ap
# Model ID handed to ``mlx_lm.load`` by ``_load_lm()`` when the detected
# device is "mps" (the 4-bit MLX quantisation).
_DEFAULT_MAC_ID = "mlx-community/Qwen2.5-7B-Instruct-4bit"
# Model ID handed to the ``transformers`` ``from_pretrained`` loaders by
# ``_load_lm()`` on every non-"mps" device (CUDA or CPU).
_DEFAULT_CUDA_ID = "Qwen/Qwen2.5-7B-Instruct"
# Process-wide cache for the loaded LM wrapper. Starts as ``None`` and is
# filled in by the first ``_get_lm()`` call; later calls reuse the same object.
_LM = None # lazy module-level singleton
def build_system_prompt() -> str:
    """Return the fixed songwriter system prompt for the Lyrics tab.

    The prompt instructs the model to emit ONLY section-tagged lyrics
    (``[intro]``, ``[verse 1]``, ``[chorus]`` and so on), states the line
    budget per section type, and forbids commentary, headers and markdown.
    Per the module's notes, this tag vocabulary is what ACE-Step's 5Hz LM
    planner expects when a draft is piped into the Generate tab.

    Returns:
        A three-line string: role statement, tag rule, then one paragraph
        of formatting guidance.
    """
    role_line = (
        "You are a songwriter. Output ONLY structured lyrics for an AI music generator."
    )
    tag_line = (
        "Use these section tags exactly: "
        "[intro] [verse 1] [verse 2] [chorus] [bridge] [outro] (etc.)"
    )
    # The remaining guidance is a single paragraph: sentences separated by
    # one space, with no trailing newline.
    guidance = " ".join(
        (
            "Each section is on its own line, followed by the lyrics for that section.",
            "Keep verses 4-8 lines, choruses 4 lines, bridges 2-4 lines.",
            "Match the requested tone and language.",
            "Do not include commentary, headers, or markdown.",
        )
    )
    return "\n".join((role_line, tag_line, guidance))
def _build_user_prompt(
    brief: str,
    structure: str,
    language: str,
    tone: str,
    verse_lines: int,
    chorus_lines: int,
    bridge_lines: int,
    rhyme: str,
) -> str:
    """Render the Lyrics-tab form fields into the user turn of the chat.

    Args:
        brief: Free-text description of the song; placed last under a
            ``Brief:`` heading.
        structure: Section layout requested by the user.
        language: Language the lyrics should be written in.
        tone: Desired tone; an empty/falsy value is rendered as ``neutral``.
        verse_lines: Requested number of lines per verse.
        chorus_lines: Requested number of lines per chorus.
        bridge_lines: Requested number of lines per bridge.
        rhyme: Rhyme scheme description.

    Returns:
        Three instruction lines, a blank line, then the brief, with a
        trailing newline.
    """
    effective_tone = tone if tone else "neutral"
    structure_line = f"Write lyrics with this structure: {structure}."
    style_line = f"Language: {language}. Tone: {effective_tone}. Rhyme: {rhyme}."
    length_line = (
        f"Verse: {verse_lines} lines. "
        f"Chorus: {chorus_lines} lines. "
        f"Bridge: {bridge_lines} lines."
    )
    # The empty element produces the blank line separating the instructions
    # from the brief.
    return "\n".join(
        (structure_line, style_line, length_line, "", f"Brief:\n{brief}\n")
    )
def _normalise(text: str) -> str:
    """Lowercase every ``[...]`` tag in *text* and trim surrounding whitespace.

    The model sometimes answers with ``[Verse 1]`` or ``[CHORUS]`` even
    though the system prompt shows lowercase tags. Per the module's notes
    the downstream ACE-Step planner wants lowercase, so the coercion is
    done once here instead of in each consumer.

    Any non-empty bracketed run is lowercased, not just known section
    names; text outside brackets is left untouched.
    """
    lowered = re.sub(
        r"\[([^\]]+)\]",
        lambda hit: f"[{hit.group(1).lower()}]",
        text,
    )
    return lowered.strip()
def _get_lm():
    """Return the shared LM wrapper, building it on first use.

    The wrapper is cached in the module-level ``_LM`` variable, so
    ``_load_lm()`` runs at most once per process as long as it succeeds.
    Per the module's notes, the tests in ``tests/test_lyrics_lm.py``
    monkeypatch this function so no real weights are loaded under pytest.
    """
    global _LM
    # Fast path: something is already cached.
    if _LM is not None:
        return _LM
    _LM = _load_lm()
    return _LM
def _load_lm():
    """Build the LM wrapper that matches the detected device.

    * ``"mps"``: load the MLX model ID through ``mlx_lm.load`` (which
      yields a ``(model, tokenizer)`` pair) and wrap it in ``_MLXLM``.
    * anything else: load the Hugging Face checkpoint through
      ``transformers`` and wrap it in ``_HFLM``; bfloat16 weights moved to
      the GPU on ``"cuda"``, float32 weights left in place otherwise.

    The backend packages are imported inside the branch that needs them
    so importing this module does not require either to be installed.
    """
    device = ap.detect_device()

    if device == "mps":
        from mlx_lm import load  # type: ignore[import-not-found]

        mlx_model, mlx_tokenizer = load(_DEFAULT_MAC_ID)
        return _MLXLM(model=mlx_model, tokenizer=mlx_tokenizer)

    # CUDA / CPU fallback via transformers.
    import torch
    from transformers import AutoModelForCausalLM, AutoTokenizer

    on_cuda = device == "cuda"
    hf_tokenizer = AutoTokenizer.from_pretrained(_DEFAULT_CUDA_ID)
    hf_model = AutoModelForCausalLM.from_pretrained(
        _DEFAULT_CUDA_ID,
        torch_dtype=torch.bfloat16 if on_cuda else torch.float32,
    )
    if on_cuda:
        hf_model = hf_model.to("cuda")
    return _HFLM(model=hf_model, tokenizer=hf_tokenizer)
@dataclass
class _MLXLM:
    """mlx-lm wrapper. ``generate`` returns a plain string (post-decode)."""

    # Model object returned by ``mlx_lm.load``.
    model: Any
    # Tokenizer object returned by ``mlx_lm.load``.
    tokenizer: Any

    def generate(self, system: str, user: str, **kw: Any) -> str:
        """Run one chat-style completion through mlx-lm.

        Args:
            system: System-turn text.
            user: User-turn text.
            **kw: Generation options. Only ``max_new_tokens`` (default 600)
                is read on this path; other keys such as ``temperature``,
                ``top_p`` and ``top_k`` are accepted but not forwarded.

        Returns:
            Whatever ``mlx_lm.generate`` returns for the assembled prompt.
        """
        import importlib

        import mlx.core as mx  # type: ignore[import-not-found]
        from mlx_lm import generate  # type: ignore[import-not-found]

        # Fetch the *submodule* through ``importlib`` rather than
        # ``import mlx_lm.generate as ...``. The ``import a.b as name`` form
        # resolves ``b`` by attribute lookup on package ``a``; when the
        # package re-exports a ``generate`` function under the same name
        # (which the ``from mlx_lm import generate`` above relies on), that
        # lookup yields the function, and the ``generation_stream``
        # assignment below would land on the function object instead of the
        # module global. ``import_module`` returns the ``sys.modules`` entry,
        # which is always the module itself.
        mlx_gen_mod = importlib.import_module("mlx_lm.generate")

        # Qwen's ChatML template — mlx-lm doesn't expose apply_chat_template
        # the way HF does, so build the prompt manually here.
        prompt = (
            f"<|im_start|>system\n{system}<|im_end|>\n"
            f"<|im_start|>user\n{user}<|im_end|>\n"
            f"<|im_start|>assistant\n"
        )
        # Gradio runs handlers in anyio worker threads. MLX maintains a
        # *per-thread* default stream and a module-level ``generation_stream``
        # that was created at mlx_lm import time on the MAIN thread. Both
        # need to be valid in the *current* (worker) thread or
        # ``wired_limit().__exit__`` crashes with "There is no Stream(gpu, 0)
        # in current thread" when it calls ``mx.synchronize(generation_stream)``.
        #
        # Two-part fix:
        #   1. ``mx.stream(mx.gpu)`` wrap installs the default GPU stream
        #      for the current thread for the duration of the call.
        #   2. Re-assign ``mlx_lm.generate.generation_stream`` to a stream
        #      created in the *current* thread so ``mx.synchronize`` doesn't
        #      reach across thread boundaries. The reassignment is safe
        #      because Gradio's queue runs at default_concurrency_limit=1 —
        #      no two lyrics drafts run concurrently.
        with mx.stream(mx.gpu):
            mlx_gen_mod.generation_stream = mx.new_stream(mx.default_device())
            return generate(
                self.model,
                self.tokenizer,
                prompt=prompt,
                max_tokens=int(kw.get("max_new_tokens", 600)),
            )
@dataclass
class _HFLM:
    """transformers wrapper. ``generate`` returns the assistant continuation."""

    # Causal LM loaded via ``AutoModelForCausalLM.from_pretrained``.
    model: Any
    # Matching tokenizer loaded via ``AutoTokenizer.from_pretrained``.
    tokenizer: Any

    def generate(self, system: str, user: str, **kw: Any) -> str:
        """Sample one assistant reply for a system + user exchange.

        Args:
            system: System-turn text.
            user: User-turn text.
            **kw: Optional sampling overrides. Recognised keys and their
                defaults: ``max_new_tokens`` (600), ``temperature`` (0.85),
                ``top_p`` (0.9), ``top_k`` (40), ``repetition_penalty``
                (1.1). Sampling is always enabled.

        Returns:
            The decoded continuation only; prompt tokens are excluded.
        """
        conversation = [
            {"role": "system", "content": system},
            {"role": "user", "content": user},
        ]
        chat_text = self.tokenizer.apply_chat_template(
            conversation,
            tokenize=False,
            add_generation_prompt=True,
        )
        encoded = self.tokenizer(chat_text, return_tensors="pt").to(self.model.device)

        sampling = {
            "max_new_tokens": int(kw.get("max_new_tokens", 600)),
            "temperature": float(kw.get("temperature", 0.85)),
            "top_p": float(kw.get("top_p", 0.9)),
            "top_k": int(kw.get("top_k", 40)),
            "repetition_penalty": float(kw.get("repetition_penalty", 1.1)),
            "do_sample": True,
        }
        sequences = self.model.generate(**encoded, **sampling)

        # Drop the prompt by *token count*, not by string prefix: decoding
        # with ``skip_special_tokens=True`` removes the ChatML markers, so
        # the decoded text would no longer start with ``chat_text`` and a
        # string-level strip would let the system/user turns leak through.
        n_prompt_tokens = int(encoded["input_ids"].shape[1])
        completion_ids = sequences[0][n_prompt_tokens:]
        return self.tokenizer.decode(completion_ids, skip_special_tokens=True)
def generate_lyrics(
    brief: str,
    structure: str,
    language: str,
    tone: str,
    verse_lines: int,
    chorus_lines: int,
    bridge_lines: int,
    rhyme: str,
    temperature: float,
    top_p: float,
    top_k: int,
    max_new_tokens: int,
    seed: int | None = None,
) -> str:
    """Draft section-tagged lyrics from the Lyrics-tab form values.

    The form fields are rendered into a user prompt, sent to the shared
    LM together with the fixed system prompt, and the reply is passed
    through ``_normalise()`` so every section tag comes back lowercase.

    Args:
        brief: Free-text description of the song.
        structure: Requested section layout.
        language: Language to write in.
        tone: Desired tone (may be empty).
        verse_lines: Lines per verse.
        chorus_lines: Lines per chorus.
        bridge_lines: Lines per bridge.
        rhyme: Rhyme scheme description.
        temperature: Sampling temperature passed to the LM wrapper.
        top_p: Nucleus-sampling cutoff passed to the LM wrapper.
        top_k: Top-k cutoff passed to the LM wrapper.
        max_new_tokens: Generation length cap passed to the LM wrapper.
        seed: Accepted for parity with the UI but not used here. It is not
            forwarded to either backend, so drafts are not reproducible.

    Returns:
        The normalised lyrics text.
    """
    writer = _get_lm()
    user_prompt = _build_user_prompt(
        brief=brief,
        structure=structure,
        language=language,
        tone=tone,
        verse_lines=verse_lines,
        chorus_lines=chorus_lines,
        bridge_lines=bridge_lines,
        rhyme=rhyme,
    )
    sampling = {
        "temperature": temperature,
        "top_p": top_p,
        "top_k": top_k,
        "max_new_tokens": max_new_tokens,
    }
    draft = writer.generate(
        system=build_system_prompt(),
        user=user_prompt,
        **sampling,
    )
    return _normalise(draft)
|