File size: 6,268 Bytes
cdc4405
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
# Copyright (c) 2026 Scenema AI
# https://scenema.ai
# SPDX-License-Identifier: MIT

"""Inference orchestration for Scenema Audio.

Generates audio for planned chunks with A2V voice conditioning between
chunks and concatenates the results. A2V reference from each chunk's tail
guides the next chunk toward a consistent voice, which SeedVC then
polishes for exact identity matching.
"""

import logging

import numpy as np

from .audio_utils import normalize_volume, trim_silence
from .chunker import ChunkSpec
from .engine import AudioEngine, AudioResult
from .whisper_aligner import validate_text

logger = logging.getLogger(__name__)

# Default length, in seconds, of the chunk tail used as the A2V reference
# for the next chunk (default value of generate_chunks' ref_duration_s).
REF_TAIL_SECONDS = 3.0
# Number of regeneration attempts allowed after the first one when Whisper
# validation fails (the retry loop runs MAX_RETRIES + 1 times in total).
MAX_RETRIES = 3
# Multiplier applied to the requested duration before each retry.
RETRY_DURATION_FACTOR = 1.3
# Default minimum word-match ratio handed to validate_text
# (default value of generate_chunks' min_match_ratio).
MIN_WORD_MATCH_RATIO = 0.90


def generate_chunks(
    engine: AudioEngine,
    chunks: list[ChunkSpec],
    ref_latent=None,
    ref_duration_s: float = REF_TAIL_SECONDS,
    validate: bool = False,
    min_match_ratio: float = MIN_WORD_MATCH_RATIO,
    anchor_ref: bool = False,
) -> list[AudioResult]:
    """Generate audio for all chunks with A2V voice conditioning.

    Each chunk gets its own Gemma encode (since each has different text).
    The tail of each chunk's audio is encoded via Audio VAE and used as
    A2V reference for the next chunk, guiding voice consistency. SeedVC
    is applied afterward by the processor for exact identity matching.

    Args:
        engine: AudioEngine instance
        chunks: List of ChunkSpec from plan_chunks()
        ref_latent: Initial reference latent (from user-provided voice URL)
        ref_duration_s: Seconds of tail audio to use as A2V reference
        validate: If True, run Whisper validation with retry loop.
            If False (default), generate once without validation.
        min_match_ratio: Minimum word-match ratio passed to validate_text
            when ``validate`` is True.
        anchor_ref: If True, every chunk uses ref_latent instead of
            chaining from the previous chunk's tail. Keeps voice
            anchored to the external reference.

    Returns:
        One AudioResult per chunk, in input order. With ``validate`` enabled
        this is the attempt with the highest word-match ratio for that chunk.
    """
    results: list[AudioResult] = []

    for i, chunk in enumerate(chunks):
        label = "with ref" if ref_latent is not None else "no ref"
        logger.info(
            "Chunk %d/%d (%s, %.1fs): %s",
            i + 1,
            len(chunks),
            label,
            chunk.duration_s,
            chunk.expected_text[:60] + ("..." if len(chunk.expected_text) > 60 else ""),
        )

        # Gemma encode once per chunk (reused across retries)
        logger.info("Compiled prompt: %s", chunk.compiled_prompt)
        vc, ac = engine.encode_text(chunk.compiled_prompt)

        duration = chunk.duration_s
        seed = chunk.seed

        if not validate:
            # Single generation, no whisper validation
            best_result = engine.generate(
                vc, ac, duration, seed, ref_latent=ref_latent
            )
        else:
            # Validation retry loop with whisper
            best_result = None
            best_ratio = -1.0

            for attempt in range(MAX_RETRIES + 1):
                result = engine.generate(vc, ac, duration, seed, ref_latent=ref_latent)

                passed, transcribed, ratio = validate_text(
                    result.waveform_np,
                    result.sample_rate,
                    chunk.expected_text,
                    language=chunk.language,
                    min_word_ratio=min_match_ratio,
                )

                # Always keep the first attempt so best_result is never None,
                # even if the ratio is not comparable (e.g. NaN).
                if best_result is None or ratio > best_ratio:
                    best_result = result
                    best_ratio = ratio

                if passed:
                    logger.info(
                        "  Chunk %d validated: %.0f%% word match",
                        i + 1,
                        ratio * 100,
                    )
                    break

                if attempt < MAX_RETRIES:
                    # Give the model more room and a new seed; the requested
                    # duration is capped at 20.0 seconds.
                    duration = min(duration * RETRY_DURATION_FACTOR, 20.0)
                    seed += 1
                    logger.info(
                        "  Chunk %d retry %d: %.0f%% match, extending to %.1fs, seed=%d",
                        i + 1,
                        attempt + 1,
                        ratio * 100,
                        duration,
                        seed,
                    )
                else:
                    logger.warning(
                        "  Chunk %d: best %.0f%% match after %d retries, accepting",
                        i + 1,
                        best_ratio * 100,
                        MAX_RETRIES,
                    )

        results.append(best_result)

        # A2V: use tail of this chunk as reference for the next
        # In anchor mode, keep using the original ref_latent for every chunk
        if i < len(chunks) - 1 and not anchor_ref:
            # Cut the tail from the result that was actually kept for this
            # chunk, not from the last retry attempt, so the next chunk is
            # conditioned on audio that is present in the final output.
            # NOTE: if tail_samples is 0 the slice [-0:] yields the whole
            # waveform.
            tail_samples = int(ref_duration_s * best_result.sample_rate)
            tail_wav = best_result.waveform_np[-tail_samples:]
            ref_latent = engine.encode_reference(tail_wav, best_result.sample_rate)

    return results


def concatenate_chunks(
    results: list[AudioResult],
    trim: bool = True,
    normalize: bool = True,
) -> tuple[np.ndarray, int]:
    """Join per-chunk waveforms into one continuous waveform.

    Each chunk is optionally silence-trimmed and then optionally
    volume-normalized on its own before the pieces are placed back to back
    along axis 0. No crossfade is applied at the seams. The sample rate of
    the first chunk is used for all processing and is the one returned.

    Args:
        results: List of AudioResult from generate_chunks().
        trim: Whether to run trim_silence on each chunk.
        normalize: Whether to run normalize_volume on each chunk.

    Returns:
        Tuple of (concatenated waveform numpy array, sample_rate).

    Raises:
        ValueError: If ``results`` is empty.
    """
    if not results:
        raise ValueError("No chunks to concatenate")

    sample_rate = results[0].sample_rate
    pieces: list[np.ndarray] = []

    for index, chunk in enumerate(results):
        audio = chunk.waveform_np
        if trim:
            audio = trim_silence(audio, sample_rate, max_silence=0.5)
        if normalize:
            audio = normalize_volume(audio, sample_rate)
        pieces.append(audio)

        # Report how much each chunk shrank after processing.
        processed_seconds = audio.shape[0] / sample_rate
        logger.debug(
            "Chunk %d: %.1fs -> %.1fs",
            index,
            chunk.duration_s,
            processed_seconds,
        )

    joined = np.concatenate(pieces, axis=0)
    total_seconds = joined.shape[0] / sample_rate
    logger.info("Concatenated: %.1fs from %d chunks", total_seconds, len(pieces))
    return joined, sample_rate