| """Audio reference conditioning item for IC-LoRA voice cloning.""" |
|
|
| import torch |
|
|
| from ltx_core.components.patchifiers import AudioPatchifier |
| from ltx_core.conditioning.item import ConditioningItem |
| from ltx_core.tools import AudioLatentTools |
| from ltx_core.types import AudioLatentShape, LatentState |
|
|
|
|
class AudioConditionByReferenceLatent(ConditioningItem):
    """Condition audio generation on a reference audio latent (voice cloning).

    Audio counterpart of VideoConditionByReferenceLatent. Applying the item:
      - patchifies the reference latent [B, C, T, F] into a token sequence;
      - derives the tokens' positions from the patchifier's patch grid bounds;
      - gives every reference token a denoise mask of ``1.0 - strength``
        (strength=1.0 -> mask=0 -> reference stays frozen);
      - builds an ASYMMETRIC attention mask: target->ref = 1 (targets may
        attend to the reference), ref->target = 0 (the reference never reads
        the targets), ref->ref = 1;
      - APPENDS the reference tokens to the END of the existing sequence
        (IC-LoRA pattern).

    Per the original design notes, reference positions are meant to share the
    target coordinate space (overlapping positions) so that RoPE does not
    decay target->ref attention; the asymmetric mask is what marks the
    reference tokens as conditioning rather than reconstruction targets.

    Args:
        latent: Reference audio latent [B, C, T, F] (pre-VAE-encoded).
        strength: Conditioning strength. 1.0 = full (ref kept clean),
            0.0 = none (ref fully denoised). Default 1.0.
    """

    def __init__(self, latent: torch.Tensor, strength: float = 1.0) -> None:
        self.latent = latent
        self.strength = strength

    def apply_to(
        self,
        latent_state: LatentState,
        latent_tools: AudioLatentTools,
    ) -> LatentState:
        """Return a new state with the reference tokens appended.

        Args:
            latent_state: Current state; its tensors are concatenated with
                the reference tensors (tokens along dim 1, positions along
                dim 2). An existing ``attention_mask`` is preserved for the
                target->target block.
            latent_tools: Supplies the patchifier used for both tokenisation
                and position computation.

        Returns:
            A fresh ``LatentState`` covering target + reference tokens, with
            a (batch, total, total) float32 attention mask.
        """
        ref_latent = self.latent
        device = ref_latent.device
        patchifier = latent_tools.patchifier

        ref_tokens = patchifier.patchify(ref_latent)
        batch_size = ref_tokens.shape[0]
        num_ref = ref_tokens.shape[1]
        num_target = latent_state.latent.shape[1]
        total = num_target + num_ref

        # Positions of the reference tokens, computed from the latent's own
        # [B, C, T, F] dimensions.
        dims = ref_latent.shape
        grid_bounds = patchifier.get_patch_grid_bounds(
            output_shape=AudioLatentShape(
                batch=dims[0],
                channels=dims[1],
                frames=dims[2],
                mel_bins=dims[3],
            ),
            device=device,
        )
        # NOTE(review): the reason for this constant 0.5 offset is not visible
        # in this file -- confirm it against the position convention used for
        # the target tokens.
        ref_positions = grid_bounds + 0.5

        # One denoise value per reference token: 0 keeps the token clean,
        # 1 lets it be fully denoised.
        ref_denoise = torch.full(
            size=(batch_size, num_ref, 1),
            fill_value=1.0 - self.strength,
            device=device,
            dtype=torch.float32,
        )

        # Attention mask over the concatenated sequence (rows = queries,
        # columns = keys). Starting from zeros leaves ref->target blocked.
        attn = torch.zeros(
            (batch_size, total, total),
            device=device,
            dtype=torch.float32,
        )
        # Target->target: keep whatever mask the state already carried,
        # otherwise allow full attention.
        prior_mask = latent_state.attention_mask
        attn[:, :num_target, :num_target] = 1.0 if prior_mask is None else prior_mask
        # Reference keys are visible to every query: this single write covers
        # both the target->ref and the ref->ref blocks.
        attn[:, :, num_target:] = 1.0

        return LatentState(
            latent=torch.cat([latent_state.latent, ref_tokens], dim=1),
            denoise_mask=torch.cat([latent_state.denoise_mask, ref_denoise], dim=1),
            positions=torch.cat([latent_state.positions, ref_positions], dim=2),
            clean_latent=torch.cat([latent_state.clean_latent, ref_tokens], dim=1),
            attention_mask=attn,
        )
|
|