File size: 7,760 Bytes
7e0eb32
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
"""Prompt chunking for long-form DramaBox generation.

The base LTX-2.3 audio DiT was trained on clips <= ~20 s. The silence-prior
patch in ``inference_server.py`` keeps generations sane up to ~45 s, but the
prior re-emerges past that boundary. For arbitrary-length prompts we split the
text into < 45 s chunks, generate each conditioned on the same voice reference,
and crossfade them back together.

Chunking is quote-aware (sentence terminators inside ``"..."`` don't count)
and preserves the speaker-description prefix on every chunk so the model keeps
the same persona / delivery style across joins.
"""
from __future__ import annotations

import re
from dataclasses import dataclass
from typing import List, Optional


# Matches the leading speaker description: the shortest run of at least three
# quote-free characters, anchored at the start of the prompt, that ends at a
# comma followed (after optional whitespace) by an opening single or double
# quote. Group 1 is the persona text, group 2 is the comma plus any whitespace;
# the quote itself is only a lookahead and is not consumed by the match.
# Group 1 is treated as persona/style metadata and re-attached to every chunk.
#   "A shadowy villain speaks with cold menace, \"You have entered...\""
#    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
# NOTE: group 1 excludes both quote characters, so a persona that contains an
# apostrophe before its closing comma cannot be captured in full.
_PREFIX_RE = re.compile(r'^([^"\']{3,}?)(,\s*)(?=["\'])', re.DOTALL)


@dataclass
class PromptChunk:
    """A piece of prompt text paired with its estimated spoken duration."""

    # Prompt text for this chunk.
    text: str
    # Estimated duration of ``text`` in seconds, as computed by the chunker.
    est_duration_s: float


def extract_speaker_prefix(prompt: str) -> tuple[Optional[str], str]:
    """Separate the leading speaker description from the rest of a prompt.

    Args:
        prompt: Full prompt text, ideally shaped like
            ``<persona>, "<dialogue>"...``.

    Returns:
        A ``(prefix, body)`` pair. When ``_PREFIX_RE`` matches, ``prefix`` is
        the persona text with surrounding whitespace stripped (the separating
        comma is not included) and ``body`` is everything after the comma and
        its trailing whitespace, starting at the opening quote. When it does
        not match, the result is ``(None, prompt)`` and the whole prompt is
        the body.
    """
    match = _PREFIX_RE.match(prompt)
    if match is None:
        # Not in the canonical form: nothing to propagate across chunks.
        return None, prompt

    persona = match.group(1).strip()
    remainder = prompt[match.end():]
    return persona, remainder


def split_sentences_outside_quotes(text: str) -> List[str]:
    """Split ``text`` into sentences, ignoring terminators inside quotes.

    A "sentence" here is a span ending in ``.``/``!``/``?`` (optionally followed
    by closing quotes / brackets) at the top level — i.e. not inside an open
    ``"..."`` or ``'...'`` pair — and followed by whitespace or end of input.
    A closing ``"`` whose immediately preceding character is a terminator also
    ends a sentence under the same whitespace / end-of-input rule. Empty /
    whitespace-only fragments are dropped, and any unterminated tail is
    returned as a final sentence.

    Args:
        text: The text to split.

    Returns:
        The sentences in order, each stripped of surrounding whitespace.
        Every non-whitespace character of ``text`` appears exactly once.

    Examples:
        >>> split_sentences_outside_quotes('He says, "Hi, how are you?" Then leaves.')
        ['He says, "Hi, how are you?"', 'Then leaves.']
        >>> split_sentences_outside_quotes('Wait...what? Yes.')
        ['Wait...what?', 'Yes.']
    """
    sentences: List[str] = []
    buf: List[str] = []
    in_double = False
    in_single = False
    i = 0
    n = len(text)
    while i < n:
        ch = text[i]
        buf.append(ch)

        if ch == '"' and not in_single:
            was_inside = in_double
            in_double = not in_double
            # Treat the *closing* quote as a sentence boundary if the char
            # right before it was a terminator: ``...how are you?"``.
            if was_inside and len(buf) >= 2 and buf[-2] in ".!?":
                # Boundary requires whitespace / end-of-string after.
                if i + 1 >= n or text[i + 1].isspace():
                    sentence = "".join(buf).strip()
                    if sentence:
                        sentences.append(sentence)
                    buf = []
                    i += 1
                    continue

        elif ch == "'" and not in_double:
            # Apostrophes inside a word (don't, it's) are not quote toggles.
            prev = text[i - 1] if i > 0 else " "
            nxt = text[i + 1] if i + 1 < n else " "
            if not (prev.isalpha() and nxt.isalpha()):
                in_single = not in_single

        elif ch in ".!?" and not in_double and not in_single:
            # Look ahead over trailing closing quotes / punctuation WITHOUT
            # touching ``buf`` or the quote state yet. Committing early would
            # duplicate these characters (and double-toggle the quote flag)
            # when no boundary follows, because the main loop would then
            # visit them a second time ("Wait...what" -> "Wait......what").
            j = i + 1
            while j < n and text[j] in '."\')]':
                j += 1
            if j >= n or text[j].isspace():
                trailing = text[i + 1:j]
                buf.extend(trailing)
                # Each consumed double quote flips the quote state once.
                if trailing.count('"') % 2:
                    in_double = not in_double
                sentence = "".join(buf).strip()
                if sentence:
                    sentences.append(sentence)
                buf = []
                i = j
                continue
        i += 1

    # Whatever is left has no terminator; keep it as the last sentence.
    tail = "".join(buf).strip()
    if tail:
        sentences.append(tail)
    return sentences


def _assemble(prefix: Optional[str], sentences: List[str]) -> str:
    body = " ".join(s.strip() for s in sentences if s.strip())
    if not prefix:
        return body
    # Re-attach prefix in the canonical "persona, body" form. If the first
    # sentence already starts with a stage direction (no opening quote), drop
    # the comma + use a period so the syntax reads naturally.
    if body.lstrip().startswith(("'", '"')):
        return f"{prefix}, {body}"
    return f"{prefix}. {body}"


def chunk_prompt_for_duration(
    prompt: str,
    max_duration_s: float = 45.0,
    target_duration_s: float = 37.0,
    duration_multiplier: float = 1.1,
) -> List[PromptChunk]:
    """Split ``prompt`` into chunks of roughly ``target_duration_s`` each.

    Sentences (as found by ``split_sentences_outside_quotes``) are packed
    greedily; the speaker prefix, if any, is re-attached to every chunk and is
    included in each chunk's duration estimate.

    Args:
        prompt: Full scene prompt (DramaBox format or plain text).
        max_duration_s: Hard cap per chunk. A prompt whose estimate is at or
            below this is returned unchanged. The one exception to the cap is
            a single sentence that on its own exceeds it: that sentence is
            emitted as its own over-cap chunk rather than being cut.
        target_duration_s: Soft cap; we close the current chunk when adding
            the next sentence would push it past this. Leaving 5-10 s of
            headroom below ``max_duration_s`` keeps us safe against the
            estimator under-shooting by ~10-15% on action-heavy prompts.
        duration_multiplier: Same breathing-room multiplier the inference
            server applies in ``estimate_duration``; matches the per-chunk
            target the model is actually asked to generate.

    Returns:
        Non-empty list of :class:`PromptChunk`. Single-chunk prompts return a
        1-element list with the original prompt unchanged.
    """
    # Imported lazily so this module can be loaded without the estimator.
    from duration_estimator import estimate_speech_duration

    def _est(t: str) -> float:
        return estimate_speech_duration(t) * duration_multiplier

    total = _est(prompt)
    if total <= max_duration_s:
        return [PromptChunk(text=prompt, est_duration_s=total)]

    prefix, body = extract_speaker_prefix(prompt)
    sentences = split_sentences_outside_quotes(body)
    if not sentences:
        # The splitter only comes back empty for a whitespace-only body, so
        # there is nothing to tokenize either. Hand the prompt back as one
        # chunk instead of returning an empty list.
        return [PromptChunk(text=prompt, est_duration_s=total)]

    chunks: List[PromptChunk] = []
    current: List[str] = []
    # Invariant: whenever ``current`` is non-empty, ``current_dur`` equals
    # ``_est(_assemble(prefix, current))``, so flushes can reuse it instead of
    # calling the estimator again.
    current_dur = 0.0

    for sent in sentences:
        cand_dur = _est(_assemble(prefix, current + [sent]))

        if current and cand_dur > target_duration_s:
            # Close the current chunk before adding this sentence.
            chunks.append(
                PromptChunk(
                    text=_assemble(prefix, current),
                    est_duration_s=current_dur,
                )
            )
            current = [sent]
            current_dur = _est(_assemble(prefix, current))
        else:
            current.append(sent)
            current_dur = cand_dur

        # Pathological case: a single sentence whose estimator output is
        # already past max_duration_s. Emit it on its own rather than raising;
        # the intent is that downstream generation truncates the request at
        # the model's hard limit, giving a degraded but non-crashing result.
        if len(current) == 1 and current_dur > max_duration_s:
            chunks.append(
                PromptChunk(
                    text=_assemble(prefix, current),
                    est_duration_s=current_dur,
                )
            )
            current = []
            current_dur = 0.0

    if current:
        chunks.append(
            PromptChunk(
                text=_assemble(prefix, current),
                est_duration_s=current_dur,
            )
        )

    return chunks