File size: 9,165 Bytes
9c07a74
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a81cc03
0d2d5a8
9c07a74
 
 
 
 
 
 
 
 
a81cc03
0d2d5a8
 
 
 
 
 
 
 
 
 
 
 
 
 
a81cc03
0d2d5a8
a81cc03
 
 
 
 
 
9c07a74
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
03ddd85
 
 
 
 
 
 
 
9c07a74
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
"""Qwen 2.5 7B Instruct as the lyrics writer.

Mac path: ``mlx-lm`` with the 4-bit MLX quantisation (``mlx-community/
Qwen2.5-7B-Instruct-4bit``) for speed and a low VRAM footprint on Apple
Silicon. The 4-bit pack is ~4 GB on disk and runs in ~8-12 s per draft
on an M5 Max after the first warm-up.

CUDA / CPU path: ``transformers`` with the full ``Qwen/Qwen2.5-7B-Instruct``
checkpoint, ``apply_chat_template`` for the prompt, and ``do_sample=True``
generation.

Loading is lazy — the module-level ``_LM`` singleton is constructed on the
first call to ``_get_lm()`` so module import stays fast for CI and so the
~4 GB MLX download is only triggered when the user actually clicks
"▶ Draft lyrics" in the Lyrics tab.

Tests in ``tests/test_lyrics_lm.py`` mock ``_get_lm`` at the module
boundary so the real Qwen weights are never loaded in CI.
"""

from __future__ import annotations

import re
from dataclasses import dataclass
from typing import Any

import ace_pipeline as ap

# Model ID passed to ``mlx_lm.load`` in ``_load_lm()`` when the detected
# device is "mps" (the 4-bit MLX pack described in the module docstring).
_DEFAULT_MAC_ID = "mlx-community/Qwen2.5-7B-Instruct-4bit"
# Model ID passed to the ``transformers`` loaders in ``_load_lm()`` on every
# other device (CUDA or CPU).
_DEFAULT_CUDA_ID = "Qwen/Qwen2.5-7B-Instruct"

# Cache for the loaded LM wrapper. Stays ``None`` until ``_get_lm()`` is
# first called, which assigns the result of ``_load_lm()`` here.
_LM = None  # lazy module-level singleton


def build_system_prompt() -> str:
    """Return the fixed songwriter system prompt used by the Lyrics tab.

    The prompt is three lines long: a role statement, the list of section
    tags (``[intro]`` ``[verse 1]`` ``[chorus]`` etc.) the model must use,
    and a single line of formatting rules. The tag vocabulary is the one
    ACE-Step's 5Hz LM planner expects when a draft is handed over to the
    Generate tab through the "Use these in Generate" button.
    """
    role_line = "You are a songwriter. Output ONLY structured lyrics for an AI music generator."
    tag_line = (
        "Use these section tags exactly: "
        "[intro] [verse 1] [verse 2] [chorus] [bridge] [outro] (etc.)"
    )
    # The formatting rules all live on one line, separated by single spaces.
    rule_sentences = (
        "Each section is on its own line, followed by the lyrics for that section.",
        "Keep verses 4-8 lines, choruses 4 lines, bridges 2-4 lines.",
        "Match the requested tone and language.",
        "Do not include commentary, headers, or markdown.",
    )
    return "\n".join((role_line, tag_line, " ".join(rule_sentences)))


def _build_user_prompt(
    brief: str,
    structure: str,
    language: str,
    tone: str,
    verse_lines: int,
    chorus_lines: int,
    bridge_lines: int,
    rhyme: str,
) -> str:
    """Render the Lyrics-tab form fields into the user turn of the chat.

    The result is a short header (structure, language / tone / rhyme and
    per-section line counts), a blank line, then ``Brief:`` followed by the
    free-text brief and a trailing newline. A falsy ``tone`` is rendered
    as ``neutral``.
    """
    effective_tone = tone if tone else "neutral"
    prompt_lines = [
        f"Write lyrics with this structure: {structure}.",
        f"Language: {language}. Tone: {effective_tone}. Rhyme: {rhyme}.",
        f"Verse: {verse_lines} lines. Chorus: {chorus_lines} lines. Bridge: {bridge_lines} lines.",
        "",  # blank separator line before the brief
        "Brief:",
        f"{brief}",
        "",  # yields the trailing newline
    ]
    return "\n".join(prompt_lines)


def _normalise(text: str) -> str:
    """Lowercase every ``[...]`` tag in *text* and trim the ends.

    Qwen sometimes answers with ``[Verse 1]`` or ``[CHORUS]`` even though
    the system prompt lists lowercase tags. ACE-Step's 5Hz LM expects
    lowercase, so the coercion happens once here instead of in every
    downstream consumer. Only the text between a ``[`` and the next ``]``
    is lowercased; everything else is left as-is apart from stripping
    leading and trailing whitespace.
    """
    bracketed = re.compile(r"\[([^\]]+)\]")
    lowered = bracketed.sub(lambda hit: "[" + hit.group(1).lower() + "]", text)
    return lowered.strip()


def _get_lm():
    """Return the shared LM wrapper, loading it on first use.

    The wrapper is cached in the module-level ``_LM``; once it is set,
    later calls return it without calling ``_load_lm()`` again. Tests in
    ``tests/test_lyrics_lm.py`` monkeypatch this function, so
    ``_load_lm()`` is never invoked under pytest.
    """
    global _LM
    if _LM is not None:
        return _LM
    # First call in this process: build the wrapper and remember it.
    _LM = _load_lm()
    return _LM


def _load_lm():
    """Build the LM wrapper that matches the detected device.

    When ``ap.detect_device()`` reports ``"mps"`` the model is loaded with
    ``mlx_lm.load`` and wrapped in ``_MLXLM``. For any other device the
    ``transformers`` tokenizer and causal-LM classes are used and the
    result is wrapped in ``_HFLM``: bfloat16 weights moved to the GPU on
    ``"cuda"``, float32 weights left where ``from_pretrained`` put them
    otherwise.
    """
    device = ap.detect_device()

    if device == "mps":
        # Imported lazily so the MLX stack is only required on the MPS path.
        from mlx_lm import load  # type: ignore[import-not-found]

        mlx_model, mlx_tokenizer = load(_DEFAULT_MAC_ID)
        return _MLXLM(model=mlx_model, tokenizer=mlx_tokenizer)

    # CUDA / CPU fallback path; the heavy imports stay local to this branch.
    import torch
    from transformers import AutoModelForCausalLM, AutoTokenizer

    on_cuda = device == "cuda"
    hf_tokenizer = AutoTokenizer.from_pretrained(_DEFAULT_CUDA_ID)
    # bfloat16 on CUDA, float32 on CPU.
    weights_dtype = torch.bfloat16 if on_cuda else torch.float32
    hf_model = AutoModelForCausalLM.from_pretrained(
        _DEFAULT_CUDA_ID,
        torch_dtype=weights_dtype,
    )
    if on_cuda:
        hf_model = hf_model.to("cuda")
    return _HFLM(model=hf_model, tokenizer=hf_tokenizer)


@dataclass
class _MLXLM:
    """mlx-lm wrapper. ``generate`` returns a plain string (post-decode).

    Holds the ``(model, tokenizer)`` pair that ``_load_lm()`` gets back
    from ``mlx_lm.load`` and offers the same ``generate(system, user,
    **kw)`` call shape as ``_HFLM``.
    """

    # Model object returned by ``mlx_lm.load``.
    model: Any
    # Tokenizer object returned by ``mlx_lm.load``.
    tokenizer: Any

    def generate(self, system: str, user: str, **kw: Any) -> str:
        """Generate the assistant reply for one system + user exchange.

        ``system`` and ``user`` are inserted verbatim into a hand-built
        ChatML prompt that ends with an open assistant turn. Of the
        keyword arguments only ``max_new_tokens`` (default 600) is read
        and forwarded as ``max_tokens``; any other keyword (for example
        ``temperature``, ``top_p`` or ``top_k``) is accepted but not
        passed on to ``mlx_lm.generate``.

        Returns whatever ``mlx_lm.generate`` returns for that prompt.
        """
        # Imported lazily so the MLX stack is only needed when this
        # wrapper is actually used.
        import mlx.core as mx  # type: ignore[import-not-found]
        import mlx_lm.generate as mlx_gen_mod  # type: ignore[import-not-found]
        from mlx_lm import generate  # type: ignore[import-not-found]

        # Qwen's ChatML template — mlx-lm doesn't expose apply_chat_template
        # the way HF does, so build the prompt manually here.
        prompt = (
            f"<|im_start|>system\n{system}<|im_end|>\n"
            f"<|im_start|>user\n{user}<|im_end|>\n"
            f"<|im_start|>assistant\n"
        )
        # Gradio runs handlers in anyio worker threads. MLX maintains a
        # *per-thread* default stream and a module-level ``generation_stream``
        # that was created at mlx_lm import time on the MAIN thread. Both
        # need to be valid in the *current* (worker) thread or
        # ``wired_limit().__exit__`` crashes with "There is no Stream(gpu, 0)
        # in current thread" when it calls ``mx.synchronize(generation_stream)``.
        #
        # Two-part fix:
        #   1. ``mx.stream(mx.gpu)`` wrap installs the default GPU stream
        #      for the current thread for the duration of the call.
        #   2. Re-assign ``mlx_lm.generate.generation_stream`` to a stream
        #      created in the *current* thread so ``mx.synchronize`` doesn't
        #      reach across thread boundaries. The reassignment is safe
        #      because Gradio's queue runs at default_concurrency_limit=1 —
        #      no two lyrics drafts run concurrently.
        #
        # NOTE(review): the reassignment mutates module-level state in
        # ``mlx_lm.generate``; it relies on the single-concurrency
        # assumption above — confirm if the queue settings ever change.
        with mx.stream(mx.gpu):
            mlx_gen_mod.generation_stream = mx.new_stream(mx.default_device())
            return generate(
                self.model,
                self.tokenizer,
                prompt=prompt,
                max_tokens=int(kw.get("max_new_tokens", 600)),
            )


@dataclass
class _HFLM:
    """transformers wrapper. ``generate`` returns the assistant continuation.

    Holds the causal-LM model and tokenizer that ``_load_lm()`` builds on
    the CUDA / CPU path and offers the same ``generate(system, user,
    **kw)`` call shape as ``_MLXLM``.
    """

    # Causal LM exposing ``.device`` and ``.generate(**kwargs)``.
    model: Any
    # Tokenizer exposing ``apply_chat_template``, ``__call__`` and ``decode``.
    tokenizer: Any

    def generate(self, system: str, user: str, **kw: Any) -> str:
        """Generate the assistant reply for one system + user exchange.

        The two messages are rendered with the tokenizer's chat template
        (with the generation prompt appended), tokenised, moved to the
        model's device and passed to ``model.generate``.

        Recognised keyword arguments and their defaults:
        ``max_new_tokens`` (600), ``temperature`` (0.85), ``top_p`` (0.9),
        ``top_k`` (40) and ``repetition_penalty`` (1.1). Unknown keywords
        are ignored.

        A ``temperature`` of zero or below selects greedy decoding:
        ``do_sample=False`` is sent and the sampling knobs are left out.
        Forwarding a non-positive temperature together with
        ``do_sample=True`` is rejected by transformers, so a UI slider
        parked at 0 would otherwise fail the whole draft.

        Returns the decoded text of the newly generated tokens only, with
        special tokens stripped.
        """
        messages = [
            {"role": "system", "content": system},
            {"role": "user", "content": user},
        ]
        prompt = self.tokenizer.apply_chat_template(
            messages,
            tokenize=False,
            add_generation_prompt=True,
        )
        inputs = self.tokenizer(prompt, return_tensors="pt").to(self.model.device)

        # Options that apply to both sampling and greedy decoding.
        gen_kwargs: dict[str, Any] = {
            "max_new_tokens": int(kw.get("max_new_tokens", 600)),
            "repetition_penalty": float(kw.get("repetition_penalty", 1.1)),
        }
        temperature = float(kw.get("temperature", 0.85))
        if temperature > 0.0:
            gen_kwargs.update(
                do_sample=True,
                temperature=temperature,
                top_p=float(kw.get("top_p", 0.9)),
                top_k=int(kw.get("top_k", 40)),
            )
        else:
            # Temperature 0 means "always pick the most likely token";
            # express that as greedy decoding instead of an invalid
            # sampling temperature.
            gen_kwargs["do_sample"] = False

        out = self.model.generate(**inputs, **gen_kwargs)

        # Slice off the prompt tokens at the *token* level. Doing it at the
        # string level (full.startswith(prompt)) is brittle because
        # ``skip_special_tokens=True`` strips the ChatML markers from
        # ``full`` but they're still present in ``prompt`` — so the prefix
        # match fails and the system + user turns leak into the output.
        prompt_len = int(inputs["input_ids"].shape[1])
        generated_ids = out[0][prompt_len:]
        return self.tokenizer.decode(generated_ids, skip_special_tokens=True)


def generate_lyrics(
    brief: str,
    structure: str,
    language: str,
    tone: str,
    verse_lines: int,
    chorus_lines: int,
    bridge_lines: int,
    rhyme: str,
    temperature: float,
    top_p: float,
    top_k: int,
    max_new_tokens: int,
    seed: int | None = None,
) -> str:
    """Produce a tagged lyrics draft from the Lyrics-tab form fields.

    The form fields are rendered into a user prompt, sent to the shared LM
    together with the fixed system prompt and the sampling settings, and
    the reply is passed through ``_normalise()`` so section tags come back
    lowercase.

    ``seed`` exists for parity with the UI but is not used: neither
    backend's high-level ``generate(...)`` helper accepts a seed in the
    version we ship with. Deterministic seeding would require dropping to
    the per-step token loop, which can be added if reproducibility
    becomes a hard requirement.
    """
    writer = _get_lm()
    user_prompt = _build_user_prompt(
        brief, structure, language, tone, verse_lines, chorus_lines, bridge_lines, rhyme
    )
    # Sampling settings are forwarded as keywords; each backend wrapper
    # picks out the ones it reads.
    sampling = {
        "temperature": temperature,
        "top_p": top_p,
        "top_k": top_k,
        "max_new_tokens": max_new_tokens,
    }
    draft = writer.generate(system=build_system_prompt(), user=user_prompt, **sampling)
    return _normalise(draft)