techfreakworm committed on
Commit
9c07a74
·
unverified ·
1 Parent(s): adb7693

feat(lyrics): add qwen 2.5 7b lazy loader with mlx and transformers backends

Browse files
Files changed (3) hide show
  1. lyrics_lm.py +222 -0
  2. modes.py +65 -0
  3. tests/test_lyrics_lm.py +50 -0
lyrics_lm.py ADDED
@@ -0,0 +1,222 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Qwen 2.5 7B Instruct as the lyrics writer.
2
+
3
+ Mac path: ``mlx-lm`` with the 4-bit MLX quantisation (``mlx-community/
4
+ Qwen2.5-7B-Instruct-4bit``) for speed and a low VRAM footprint on Apple
5
+ Silicon. The 4-bit pack is ~4 GB on disk and runs in ~8-12 s per draft
6
+ on an M5 Max after the first warm-up.
7
+
8
+ CUDA / CPU path: ``transformers`` with the full ``Qwen/Qwen2.5-7B-Instruct``
9
+ checkpoint, ``apply_chat_template`` for the prompt, and ``do_sample=True``
10
+ generation.
11
+
12
+ Loading is lazy — the module-level ``_LM`` singleton is constructed on the
13
+ first call to ``_get_lm()`` so module import stays fast for CI and so the
14
+ ~4 GB MLX download is only triggered when the user actually clicks
15
+ "▶ Draft lyrics" in the Lyrics tab.
16
+
17
+ Tests in ``tests/test_lyrics_lm.py`` mock ``_get_lm`` at the module
18
+ boundary so the real Qwen weights are never loaded in CI.
19
+ """
20
+
21
+ from __future__ import annotations
22
+
23
+ import re
24
+ from dataclasses import dataclass
25
+ from typing import Any
26
+
27
+ import ace_pipeline as ap
28
+
29
+ _DEFAULT_MAC_ID = "mlx-community/Qwen2.5-7B-Instruct-4bit"
30
+ _DEFAULT_CUDA_ID = "Qwen/Qwen2.5-7B-Instruct"
31
+
32
+ _LM = None # lazy module-level singleton
33
+
34
+
35
def build_system_prompt() -> str:
    """Return the fixed songwriter system prompt used by the Lyrics tab.

    The prompt tells Qwen to emit nothing but structurally-tagged lyrics
    (``[intro]``, ``[verse 1]``, ``[chorus]``, ...). That tag vocabulary is
    the one ACE-Step's 5Hz LM planner expects when the draft is handed to
    the Generate tab via the "Use these in Generate" button.

    Returns:
        One multi-line string: two newline-terminated header lines followed
        by a single paragraph of formatting rules.
    """
    fragments = (
        "You are a songwriter. Output ONLY structured lyrics for an AI music generator.\n",
        "Use these section tags exactly: [intro] [verse 1] [verse 2] [chorus] [bridge] [outro] (etc.)\n",
        "Each section is on its own line, followed by the lyrics for that section. ",
        "Keep verses 4-8 lines, choruses 4 lines, bridges 2-4 lines. ",
        "Match the requested tone and language. ",
        "Do not include commentary, headers, or markdown.",
    )
    return "".join(fragments)
52
+
53
+
54
def _build_user_prompt(
    brief: str,
    structure: str,
    language: str,
    tone: str,
    verse_lines: int,
    chorus_lines: int,
    bridge_lines: int,
    rhyme: str,
) -> str:
    """Render the Lyrics-tab form fields into the user turn of the chat.

    An empty (falsy) ``tone`` is rendered as ``neutral``. The brief goes
    last, under its own ``Brief:`` heading, separated from the constraint
    lines by a blank line.
    """
    sections = [
        f"Write lyrics with this structure: {structure}.\n",
        f"Language: {language}. Tone: {tone or 'neutral'}. Rhyme: {rhyme}.\n",
        f"Verse: {verse_lines} lines. Chorus: {chorus_lines} lines. Bridge: {bridge_lines} lines.\n\n",
        f"Brief:\n{brief}\n",
    ]
    return "".join(sections)
70
+
71
+
72
def _normalise(text: str) -> str:
    """Lowercase section tags and strip outer whitespace.

    Qwen occasionally emits ``[Verse 1]`` or ``[CHORUS]`` despite the
    system prompt asking for lowercase tags. ACE-Step's 5Hz LM expects
    lowercase, so we coerce here rather than relying on every downstream
    consumer to lowercase before parsing.

    Only well-formed, single-line tags are rewritten. The tag body may not
    contain ``[``, ``]`` or a newline, so a stray unclosed ``[`` in the
    model output cannot extend the match across the following lyric lines
    and lowercase them as a side effect.

    Args:
        text: Raw model output.

    Returns:
        ``text`` with every ``[...]`` tag lowercased and leading/trailing
        whitespace removed; lyric lines themselves are left untouched.
    """

    def lower_tag(match: re.Match[str]) -> str:
        return "[" + match.group(1).lower() + "]"

    # ``[^\[\]\n]+`` keeps each match inside one bracket pair on one line.
    return re.sub(r"\[([^\[\]\n]+)\]", lower_tag, text).strip()
85
+
86
+
87
def _get_lm():
    """Return the process-wide LM wrapper, building it on first use.

    The wrapper is cached in the module-level ``_LM`` variable, so
    ``_load_lm()`` runs at most once per process as long as calls do not
    race. Tests in ``tests/test_lyrics_lm.py`` monkeypatch this function,
    so the real loader is never reached under pytest.
    """
    global _LM
    if _LM is not None:
        return _LM
    _LM = _load_lm()
    return _LM
99
+
100
+
101
def _load_lm():
    """Build the LM wrapper that matches the detected device.

    * ``mps``: ``mlx_lm.load`` on the 4-bit MLX checkpoint, wrapped in
      ``_MLXLM``.
    * anything else: ``transformers`` on the full checkpoint, wrapped in
      ``_HFLM``. The weights are loaded as bfloat16 and moved to the GPU
      when the device is ``cuda``; otherwise they stay in float32 where
      ``from_pretrained`` placed them.

    The heavy third-party imports stay inside the function so importing
    this module does not pull them in.
    """
    device = ap.detect_device()

    if device != "mps":
        import torch
        from transformers import AutoModelForCausalLM, AutoTokenizer

        on_cuda = device == "cuda"
        hf_tokenizer = AutoTokenizer.from_pretrained(_DEFAULT_CUDA_ID)
        hf_model = AutoModelForCausalLM.from_pretrained(
            _DEFAULT_CUDA_ID,
            torch_dtype=torch.bfloat16 if on_cuda else torch.float32,
        )
        if on_cuda:
            hf_model = hf_model.to("cuda")
        return _HFLM(model=hf_model, tokenizer=hf_tokenizer)

    from mlx_lm import load  # type: ignore[import-not-found]

    mlx_model, mlx_tokenizer = load(_DEFAULT_MAC_ID)
    return _MLXLM(model=mlx_model, tokenizer=mlx_tokenizer)
125
+
126
+
127
@dataclass
class _MLXLM:
    """Thin wrapper around an mlx-lm ``(model, tokenizer)`` pair.

    ``generate`` hands back whatever ``mlx_lm.generate`` returns for the
    assembled prompt.
    """

    model: Any
    tokenizer: Any

    def generate(self, system: str, user: str, **kw: Any) -> str:
        """Generate the assistant turn for one system + user exchange.

        The prompt is assembled by hand in Qwen's ChatML layout and ends
        with an open ``assistant`` turn.

        Only ``max_new_tokens`` (default 600) is read from ``kw``; any
        other keyword, such as ``temperature``, ``top_p`` or ``top_k``,
        is accepted for call-site parity with ``_HFLM`` but is not
        forwarded to mlx-lm by this wrapper.
        """
        from mlx_lm import generate as mlx_generate  # type: ignore[import-not-found]

        chatml_turns = [
            f"<|im_start|>system\n{system}<|im_end|>\n",
            f"<|im_start|>user\n{user}<|im_end|>\n",
            f"<|im_start|>assistant\n",
        ]
        token_budget = int(kw.get("max_new_tokens", 600))
        return mlx_generate(
            self.model,
            self.tokenizer,
            prompt="".join(chatml_turns),
            max_tokens=token_budget,
        )
150
+
151
+
152
@dataclass
class _HFLM:
    """transformers wrapper. ``generate`` returns only the assistant continuation."""

    model: Any
    tokenizer: Any

    def generate(self, system: str, user: str, **kw: Any) -> str:
        """Sample one assistant reply for a system + user exchange.

        Args:
            system: System prompt text.
            user: User prompt text.
            **kw: Optional sampling overrides: ``max_new_tokens`` (600),
                ``temperature`` (0.85), ``top_p`` (0.9), ``top_k`` (40),
                ``repetition_penalty`` (1.1). Other keys are ignored.

        Returns:
            The decoded text of the newly generated tokens, with special
            tokens removed and without the prompt.
        """
        msgs = [
            {"role": "system", "content": system},
            {"role": "user", "content": user},
        ]
        prompt = self.tokenizer.apply_chat_template(
            msgs,
            tokenize=False,
            add_generation_prompt=True,
        )
        inputs = self.tokenizer(prompt, return_tensors="pt").to(self.model.device)
        out = self.model.generate(
            **inputs,
            max_new_tokens=int(kw.get("max_new_tokens", 600)),
            temperature=float(kw.get("temperature", 0.85)),
            top_p=float(kw.get("top_p", 0.9)),
            top_k=int(kw.get("top_k", 40)),
            repetition_penalty=float(kw.get("repetition_penalty", 1.1)),
            do_sample=True,
        )
        # The returned sequence starts with the prompt tokens. Drop them by
        # token count rather than by string prefix: the decoded text has the
        # chat-template special tokens stripped, so it never starts with the
        # raw ``prompt`` string and a prefix check would leak the system and
        # user text into the lyrics.
        prompt_len = inputs["input_ids"].shape[-1]
        return self.tokenizer.decode(out[0][prompt_len:], skip_special_tokens=True)
182
+
183
+
184
def generate_lyrics(
    brief: str,
    structure: str,
    language: str,
    tone: str,
    verse_lines: int,
    chorus_lines: int,
    bridge_lines: int,
    rhyme: str,
    temperature: float,
    top_p: float,
    top_k: int,
    max_new_tokens: int,
    seed: int | None = None,
) -> str:
    """Draft structurally-tagged lyrics for the Lyrics tab.

    The form fields are rendered into a user prompt, sent to the lazily
    loaded LM together with the fixed system prompt, and the raw draft is
    passed through ``_normalise()`` so section tags come back lowercase.

    ``seed`` exists for parity with the UI only: it is not passed to the
    LM wrapper, so repeated calls with the same seed are not guaranteed to
    produce the same draft.

    Returns:
        The normalised lyrics text.
    """
    backend_lm = _get_lm()
    user_prompt = _build_user_prompt(
        brief=brief,
        structure=structure,
        language=language,
        tone=tone,
        verse_lines=verse_lines,
        chorus_lines=chorus_lines,
        bridge_lines=bridge_lines,
        rhyme=rhyme,
    )
    sampling = {
        "temperature": temperature,
        "top_p": top_p,
        "top_k": top_k,
        "max_new_tokens": max_new_tokens,
    }
    draft = backend_lm.generate(
        system=build_system_prompt(),
        user=user_prompt,
        **sampling,
    )
    return _normalise(draft)
modes.py CHANGED
@@ -3,12 +3,19 @@
3
  Each handler validates inputs, builds the ACE-Step kwargs for its mode, and
4
  hands off to `backend.dispatch(...)`. Backend ownership of @spaces.GPU and
5
  pipeline lifecycle keeps these handlers cheap to test.
 
 
 
 
 
6
  """
7
 
8
  from __future__ import annotations
9
 
10
  from typing import Any
11
 
 
 
12
 
13
  def _require(params: dict[str, Any], field: str) -> Any:
14
  v = params.get(field)
@@ -155,3 +162,61 @@ def edit(backend, params: dict[str, Any]) -> tuple[str, dict[str, Any]]:
155
  )
156
 
157
  return backend.dispatch(mode="edit", params=out_params)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3
  Each handler validates inputs, builds the ACE-Step kwargs for its mode, and
4
  hands off to `backend.dispatch(...)`. Backend ownership of @spaces.GPU and
5
  pipeline lifecycle keeps these handlers cheap to test.
6
+
7
+ The ``lyrics()`` handler is the odd one out: it does NOT touch the ACE-Step
8
+ backend at all. It calls ``lyrics_lm.generate_lyrics`` directly, since the
9
+ Qwen 2.5 7B LM is its own lazy singleton and doesn't share the DiT / 5Hz
10
+ pipeline lifecycle with the audio modes.
11
  """
12
 
13
  from __future__ import annotations
14
 
15
  from typing import Any
16
 
17
+ import lyrics_lm
18
+
19
 
20
  def _require(params: dict[str, Any], field: str) -> Any:
21
  v = params.get(field)
 
162
  )
163
 
164
  return backend.dispatch(mode="edit", params=out_params)
165
+
166
+
167
def lyrics(backend, params: dict[str, Any]) -> tuple[str, dict[str, Any]]:
    """Lyrics-only mode. Returns ``(drafted_text, metadata_dict)``.

    Does NOT touch the ACE-Step backend — Qwen 2.5 7B Instruct is owned
    by ``lyrics_lm`` as its own lazy singleton. The ``backend`` argument
    is kept in the signature for parity with the other mode handlers but
    is unused here.

    Args:
        backend: Ignored.
        params: Form values. ``brief`` is required; every other key falls
            back to the default shown below when absent.

    Returns:
        The drafted lyrics and a metadata dict echoing the effective
        settings, including the checkpoint id for the active device.
    """
    del backend  # signature parity with generate/cover/extend/edit
    brief = _require(params, "brief")
    structure = params.get("structure", "intro, verse, chorus, verse, chorus, bridge, chorus, outro")
    language = params.get("language", "en")
    tone = params.get("tone", "")
    verse_lines = int(params.get("verse_lines", 6))
    chorus_lines = int(params.get("chorus_lines", 4))
    bridge_lines = int(params.get("bridge_lines", 2))
    rhyme = params.get("rhyme", "loose")
    temperature = float(params.get("temperature", 0.85))
    top_p = float(params.get("top_p", 0.9))
    top_k = int(params.get("top_k", 40))
    max_new_tokens = int(params.get("max_new_tokens", 600))
    seed = params.get("seed")

    text = lyrics_lm.generate_lyrics(
        brief=brief,
        structure=structure,
        language=language,
        tone=tone,
        verse_lines=verse_lines,
        chorus_lines=chorus_lines,
        bridge_lines=bridge_lines,
        rhyme=rhyme,
        temperature=temperature,
        top_p=top_p,
        top_k=top_k,
        max_new_tokens=max_new_tokens,
        seed=seed,
    )

    # Report the checkpoint that matches the device, using the same branch
    # as ``lyrics_lm._load_lm``: the MLX build on "mps", the full
    # checkpoint everywhere else. Previously the MLX id was reported
    # unconditionally.
    # NOTE(review): assumes ``detect_device()`` is cheap and side-effect
    # free on repeated calls — confirm against ``ace_pipeline``.
    if lyrics_lm.ap.detect_device() == "mps":
        model_id = lyrics_lm._DEFAULT_MAC_ID
    else:
        model_id = lyrics_lm._DEFAULT_CUDA_ID

    meta = {
        "mode": "lyrics",
        "model": model_id,
        "brief_first_line": brief.splitlines()[0] if brief else "",
        "structure": structure,
        "language": language,
        "tone": tone,
        "verse_lines": verse_lines,
        "chorus_lines": chorus_lines,
        "bridge_lines": bridge_lines,
        "rhyme": rhyme,
        "temperature": temperature,
        "top_p": top_p,
        "top_k": top_k,
        "max_new_tokens": max_new_tokens,
        "seed": seed,
    }
    return text, meta
tests/test_lyrics_lm.py ADDED
@@ -0,0 +1,50 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """L2 tests for lyrics LM — generation is mocked at the model boundary.
2
+
3
+ The real Qwen 2.5 7B model is never loaded in CI. We only verify the prompt
4
+ shape, the call boundary to ``_get_lm()``, and the normalisation pass that
5
+ lowercases section tags before returning to the caller.
6
+ """
7
+
8
+ from __future__ import annotations
9
+
10
+ from unittest.mock import MagicMock
11
+
12
+ import lyrics_lm as ll
13
+
14
+
15
def test_build_system_prompt_includes_tag_format():
    """The system prompt must mention the verse and chorus tag prefixes."""
    lowered = ll.build_system_prompt().lower()
    for tag_prefix in ("[verse", "[chorus"):
        assert tag_prefix in lowered
20
+
21
+
22
def test_generate_lyrics_calls_lm_and_returns_text(monkeypatch):
    """``generate_lyrics`` forwards prompts and sampling kwargs, then normalises.

    The fake LM answers with mixed-case tags so the assertion on the
    return value fails if the ``_normalise`` pass is ever dropped, and the
    recorded call is inspected so a silently lost prompt field or sampling
    argument is caught as well.
    """
    fake_lm = MagicMock()
    fake_lm.generate.return_value = "[Verse] x\n[CHORUS] y\n"
    monkeypatch.setattr(ll, "_get_lm", lambda: fake_lm)

    out = ll.generate_lyrics(
        brief="a song",
        structure="intro, verse, chorus, outro",
        language="en",
        tone="upbeat",
        verse_lines=4,
        chorus_lines=4,
        bridge_lines=2,
        rhyme="loose",
        temperature=0.85,
        top_p=0.9,
        top_k=40,
        max_new_tokens=200,
        seed=42,
    )

    # Tags lowercased, outer whitespace stripped.
    assert out == "[verse] x\n[chorus] y"

    fake_lm.generate.assert_called_once()
    sent = fake_lm.generate.call_args.kwargs
    assert sent["system"] == ll.build_system_prompt()
    assert "a song" in sent["user"]
    assert "intro, verse, chorus, outro" in sent["user"]
    assert "upbeat" in sent["user"]
    assert sent["temperature"] == 0.85
    assert sent["top_p"] == 0.9
    assert sent["top_k"] == 40
    assert sent["max_new_tokens"] == 200
44
+
45
+
46
def test_normalise_lyrics_lowercases_tags():
    """Mixed-case section tags come back lowercase after normalisation."""
    result = ll._normalise(" [Verse 1]\nhello\n[Chorus]\nworld ")
    for expected_tag in ("[verse 1]", "[chorus]"):
        assert expected_tag in result
    assert "[Verse" not in result