Update app.py
Browse files
app.py
CHANGED
|
@@ -65,7 +65,6 @@ from ltx_pipelines.utils.constants import DISTILLED_SIGMA_VALUES, STAGE_2_DISTIL
|
|
| 65 |
from ltx_pipelines.utils.helpers import (
|
| 66 |
cleanup_memory,
|
| 67 |
combined_image_conditionings,
|
| 68 |
-
denoise_audio_video,
|
| 69 |
denoise_video_only,
|
| 70 |
encode_prompts,
|
| 71 |
simple_denoising_func,
|
|
@@ -103,7 +102,7 @@ RESOLUTIONS = {
|
|
| 103 |
|
| 104 |
|
| 105 |
class LTX23DistilledA2VPipeline(DistilledPipeline):
|
| 106 |
-
"""DistilledPipeline
|
| 107 |
|
| 108 |
def __call__(
|
| 109 |
self,
|
|
@@ -118,7 +117,20 @@ class LTX23DistilledA2VPipeline(DistilledPipeline):
|
|
| 118 |
tiling_config: TilingConfig | None = None,
|
| 119 |
enhance_prompt: bool = False,
|
| 120 |
):
|
|
|
|
| 121 |
print(prompt)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 122 |
|
| 123 |
generator = torch.Generator(device=self.device).manual_seed(seed)
|
| 124 |
noiser = GaussianNoiser(generator=generator)
|
|
@@ -133,41 +145,32 @@ class LTX23DistilledA2VPipeline(DistilledPipeline):
|
|
| 133 |
)
|
| 134 |
video_context, audio_context = ctx_p.video_encoding, ctx_p.audio_encoding
|
| 135 |
|
| 136 |
-
|
| 137 |
-
|
| 138 |
-
|
| 139 |
-
|
| 140 |
-
|
| 141 |
-
|
| 142 |
-
|
| 143 |
-
|
| 144 |
-
|
| 145 |
-
|
| 146 |
-
|
| 147 |
-
|
| 148 |
-
|
| 149 |
-
|
| 150 |
-
|
| 151 |
-
encoded_audio_latent
|
| 152 |
-
|
| 153 |
-
|
| 154 |
-
|
| 155 |
-
|
| 156 |
-
expected_frames - actual_frames,
|
| 157 |
-
encoded_audio_latent.shape[3],
|
| 158 |
-
device=encoded_audio_latent.device,
|
| 159 |
-
dtype=encoded_audio_latent.dtype,
|
| 160 |
-
)
|
| 161 |
-
encoded_audio_latent = torch.cat([encoded_audio_latent, pad], dim=2)
|
| 162 |
-
|
| 163 |
-
original_audio = Audio(
|
| 164 |
-
waveform=decoded_audio.waveform.squeeze(0),
|
| 165 |
-
sampling_rate=decoded_audio.sampling_rate,
|
| 166 |
)
|
|
|
|
| 167 |
|
| 168 |
video_encoder = self.model_ledger.video_encoder()
|
| 169 |
transformer = self.model_ledger.transformer()
|
| 170 |
-
|
| 171 |
|
| 172 |
def denoising_loop(sigmas, video_state, audio_state, stepper):
|
| 173 |
return euler_denoising_loop(
|
|
@@ -182,26 +185,26 @@ class LTX23DistilledA2VPipeline(DistilledPipeline):
|
|
| 182 |
),
|
| 183 |
)
|
| 184 |
|
| 185 |
-
|
| 186 |
batch=1,
|
| 187 |
frames=num_frames,
|
| 188 |
-
width=width,
|
| 189 |
-
height=height,
|
| 190 |
fps=frame_rate,
|
| 191 |
)
|
| 192 |
-
|
| 193 |
images=images,
|
| 194 |
-
height=
|
| 195 |
-
width=
|
| 196 |
video_encoder=video_encoder,
|
| 197 |
dtype=dtype,
|
| 198 |
device=self.device,
|
| 199 |
)
|
| 200 |
-
video_state
|
| 201 |
-
output_shape=
|
| 202 |
-
conditionings=
|
| 203 |
noiser=noiser,
|
| 204 |
-
sigmas=
|
| 205 |
stepper=stepper,
|
| 206 |
denoising_loop_fn=denoising_loop,
|
| 207 |
components=self.pipeline_components,
|
|
@@ -210,6 +213,39 @@ class LTX23DistilledA2VPipeline(DistilledPipeline):
|
|
| 210 |
initial_audio_latent=encoded_audio_latent,
|
| 211 |
)
|
| 212 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 213 |
torch.cuda.synchronize()
|
| 214 |
del transformer
|
| 215 |
del video_encoder
|
|
@@ -221,19 +257,11 @@ class LTX23DistilledA2VPipeline(DistilledPipeline):
|
|
| 221 |
tiling_config,
|
| 222 |
generator,
|
| 223 |
)
|
| 224 |
-
|
| 225 |
-
|
| 226 |
-
|
| 227 |
-
|
| 228 |
-
|
| 229 |
-
else:
|
| 230 |
-
from ltx_core.model.audio_vae import decode_audio as vae_decode_audio
|
| 231 |
-
generated_audio = vae_decode_audio(
|
| 232 |
-
audio_state.latent,
|
| 233 |
-
self.model_ledger.audio_decoder(),
|
| 234 |
-
self.model_ledger.vocoder(),
|
| 235 |
-
)
|
| 236 |
-
return decoded_video, generated_audio
|
| 237 |
|
| 238 |
|
| 239 |
# Model repos
|
|
|
|
| 65 |
from ltx_pipelines.utils.helpers import (
|
| 66 |
cleanup_memory,
|
| 67 |
combined_image_conditionings,
|
|
|
|
| 68 |
denoise_video_only,
|
| 69 |
encode_prompts,
|
| 70 |
simple_denoising_func,
|
|
|
|
| 102 |
|
| 103 |
|
| 104 |
class LTX23DistilledA2VPipeline(DistilledPipeline):
|
| 105 |
+
"""DistilledPipeline with optional audio conditioning."""
|
| 106 |
|
| 107 |
def __call__(
|
| 108 |
self,
|
|
|
|
| 117 |
tiling_config: TilingConfig | None = None,
|
| 118 |
enhance_prompt: bool = False,
|
| 119 |
):
|
| 120 |
+
# Standard path when no audio input is provided.
|
| 121 |
print(prompt)
|
| 122 |
+
if audio_path is None:
|
| 123 |
+
return super().__call__(
|
| 124 |
+
prompt=prompt,
|
| 125 |
+
seed=seed,
|
| 126 |
+
height=height,
|
| 127 |
+
width=width,
|
| 128 |
+
num_frames=num_frames,
|
| 129 |
+
frame_rate=frame_rate,
|
| 130 |
+
images=images,
|
| 131 |
+
tiling_config=tiling_config,
|
| 132 |
+
enhance_prompt=enhance_prompt,
|
| 133 |
+
)
|
| 134 |
|
| 135 |
generator = torch.Generator(device=self.device).manual_seed(seed)
|
| 136 |
noiser = GaussianNoiser(generator=generator)
|
|
|
|
| 145 |
)
|
| 146 |
video_context, audio_context = ctx_p.video_encoding, ctx_p.audio_encoding
|
| 147 |
|
| 148 |
+
video_duration = num_frames / frame_rate
|
| 149 |
+
decoded_audio = decode_audio_from_file(audio_path, self.device, 0.0, video_duration)
|
| 150 |
+
if decoded_audio is None:
|
| 151 |
+
raise ValueError(f"Could not extract audio stream from {audio_path}")
|
| 152 |
+
|
| 153 |
+
encoded_audio_latent = vae_encode_audio(decoded_audio, self.model_ledger.audio_encoder())
|
| 154 |
+
audio_shape = AudioLatentShape.from_duration(batch=1, duration=video_duration, channels=8, mel_bins=16)
|
| 155 |
+
expected_frames = audio_shape.frames
|
| 156 |
+
actual_frames = encoded_audio_latent.shape[2]
|
| 157 |
+
|
| 158 |
+
if actual_frames > expected_frames:
|
| 159 |
+
encoded_audio_latent = encoded_audio_latent[:, :, :expected_frames, :]
|
| 160 |
+
elif actual_frames < expected_frames:
|
| 161 |
+
pad = torch.zeros(
|
| 162 |
+
encoded_audio_latent.shape[0],
|
| 163 |
+
encoded_audio_latent.shape[1],
|
| 164 |
+
expected_frames - actual_frames,
|
| 165 |
+
encoded_audio_latent.shape[3],
|
| 166 |
+
device=encoded_audio_latent.device,
|
| 167 |
+
dtype=encoded_audio_latent.dtype,
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 168 |
)
|
| 169 |
+
encoded_audio_latent = torch.cat([encoded_audio_latent, pad], dim=2)
|
| 170 |
|
| 171 |
video_encoder = self.model_ledger.video_encoder()
|
| 172 |
transformer = self.model_ledger.transformer()
|
| 173 |
+
stage_1_sigmas = torch.tensor(DISTILLED_SIGMA_VALUES, device=self.device)
|
| 174 |
|
| 175 |
def denoising_loop(sigmas, video_state, audio_state, stepper):
|
| 176 |
return euler_denoising_loop(
|
|
|
|
| 185 |
),
|
| 186 |
)
|
| 187 |
|
| 188 |
+
stage_1_output_shape = VideoPixelShape(
|
| 189 |
batch=1,
|
| 190 |
frames=num_frames,
|
| 191 |
+
width=width // 2,
|
| 192 |
+
height=height // 2,
|
| 193 |
fps=frame_rate,
|
| 194 |
)
|
| 195 |
+
stage_1_conditionings = combined_image_conditionings(
|
| 196 |
images=images,
|
| 197 |
+
height=stage_1_output_shape.height,
|
| 198 |
+
width=stage_1_output_shape.width,
|
| 199 |
video_encoder=video_encoder,
|
| 200 |
dtype=dtype,
|
| 201 |
device=self.device,
|
| 202 |
)
|
| 203 |
+
video_state = denoise_video_only(
|
| 204 |
+
output_shape=stage_1_output_shape,
|
| 205 |
+
conditionings=stage_1_conditionings,
|
| 206 |
noiser=noiser,
|
| 207 |
+
sigmas=stage_1_sigmas,
|
| 208 |
stepper=stepper,
|
| 209 |
denoising_loop_fn=denoising_loop,
|
| 210 |
components=self.pipeline_components,
|
|
|
|
| 213 |
initial_audio_latent=encoded_audio_latent,
|
| 214 |
)
|
| 215 |
|
| 216 |
+
torch.cuda.synchronize()
|
| 217 |
+
cleanup_memory()
|
| 218 |
+
|
| 219 |
+
upscaled_video_latent = upsample_video(
|
| 220 |
+
latent=video_state.latent[:1],
|
| 221 |
+
video_encoder=video_encoder,
|
| 222 |
+
upsampler=self.model_ledger.spatial_upsampler(),
|
| 223 |
+
)
|
| 224 |
+
stage_2_sigmas = torch.tensor(STAGE_2_DISTILLED_SIGMA_VALUES, device=self.device)
|
| 225 |
+
stage_2_output_shape = VideoPixelShape(batch=1, frames=num_frames, width=width, height=height, fps=frame_rate)
|
| 226 |
+
stage_2_conditionings = combined_image_conditionings(
|
| 227 |
+
images=images,
|
| 228 |
+
height=stage_2_output_shape.height,
|
| 229 |
+
width=stage_2_output_shape.width,
|
| 230 |
+
video_encoder=video_encoder,
|
| 231 |
+
dtype=dtype,
|
| 232 |
+
device=self.device,
|
| 233 |
+
)
|
| 234 |
+
video_state = denoise_video_only(
|
| 235 |
+
output_shape=stage_2_output_shape,
|
| 236 |
+
conditionings=stage_2_conditionings,
|
| 237 |
+
noiser=noiser,
|
| 238 |
+
sigmas=stage_2_sigmas,
|
| 239 |
+
stepper=stepper,
|
| 240 |
+
denoising_loop_fn=denoising_loop,
|
| 241 |
+
components=self.pipeline_components,
|
| 242 |
+
dtype=dtype,
|
| 243 |
+
device=self.device,
|
| 244 |
+
noise_scale=stage_2_sigmas[0],
|
| 245 |
+
initial_video_latent=upscaled_video_latent,
|
| 246 |
+
initial_audio_latent=encoded_audio_latent,
|
| 247 |
+
)
|
| 248 |
+
|
| 249 |
torch.cuda.synchronize()
|
| 250 |
del transformer
|
| 251 |
del video_encoder
|
|
|
|
| 257 |
tiling_config,
|
| 258 |
generator,
|
| 259 |
)
|
| 260 |
+
original_audio = Audio(
|
| 261 |
+
waveform=decoded_audio.waveform.squeeze(0),
|
| 262 |
+
sampling_rate=decoded_audio.sampling_rate,
|
| 263 |
+
)
|
| 264 |
+
return decoded_video, original_audio
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 265 |
|
| 266 |
|
| 267 |
# Model repos
|