dagloop5 committed on
Commit
c43c959
·
verified ·
1 Parent(s): 2470e81

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +84 -56
app.py CHANGED
@@ -65,7 +65,6 @@ from ltx_pipelines.utils.constants import DISTILLED_SIGMA_VALUES, STAGE_2_DISTIL
65
  from ltx_pipelines.utils.helpers import (
66
  cleanup_memory,
67
  combined_image_conditionings,
68
- denoise_audio_video,
69
  denoise_video_only,
70
  encode_prompts,
71
  simple_denoising_func,
@@ -103,7 +102,7 @@ RESOLUTIONS = {
103
 
104
 
105
  class LTX23DistilledA2VPipeline(DistilledPipeline):
106
- """DistilledPipeline: single stage, full resolution, 8 steps, with optional audio."""
107
 
108
  def __call__(
109
  self,
@@ -118,7 +117,20 @@ class LTX23DistilledA2VPipeline(DistilledPipeline):
118
  tiling_config: TilingConfig | None = None,
119
  enhance_prompt: bool = False,
120
  ):
 
121
  print(prompt)
 
 
 
 
 
 
 
 
 
 
 
 
122
 
123
  generator = torch.Generator(device=self.device).manual_seed(seed)
124
  noiser = GaussianNoiser(generator=generator)
@@ -133,41 +145,32 @@ class LTX23DistilledA2VPipeline(DistilledPipeline):
133
  )
134
  video_context, audio_context = ctx_p.video_encoding, ctx_p.audio_encoding
135
 
136
- # Audio encoding only runs if audio is provided
137
- encoded_audio_latent = None
138
- original_audio = None
139
- if audio_path is not None:
140
- video_duration = num_frames / frame_rate
141
- decoded_audio = decode_audio_from_file(audio_path, self.device, 0.0, video_duration)
142
- if decoded_audio is None:
143
- raise ValueError(f"Could not extract audio stream from {audio_path}")
144
-
145
- encoded_audio_latent = vae_encode_audio(decoded_audio, self.model_ledger.audio_encoder())
146
- audio_shape = AudioLatentShape.from_duration(batch=1, duration=video_duration, channels=8, mel_bins=16)
147
- expected_frames = audio_shape.frames
148
- actual_frames = encoded_audio_latent.shape[2]
149
-
150
- if actual_frames > expected_frames:
151
- encoded_audio_latent = encoded_audio_latent[:, :, :expected_frames, :]
152
- elif actual_frames < expected_frames:
153
- pad = torch.zeros(
154
- encoded_audio_latent.shape[0],
155
- encoded_audio_latent.shape[1],
156
- expected_frames - actual_frames,
157
- encoded_audio_latent.shape[3],
158
- device=encoded_audio_latent.device,
159
- dtype=encoded_audio_latent.dtype,
160
- )
161
- encoded_audio_latent = torch.cat([encoded_audio_latent, pad], dim=2)
162
-
163
- original_audio = Audio(
164
- waveform=decoded_audio.waveform.squeeze(0),
165
- sampling_rate=decoded_audio.sampling_rate,
166
  )
 
167
 
168
  video_encoder = self.model_ledger.video_encoder()
169
  transformer = self.model_ledger.transformer()
170
- sigmas = torch.tensor(DISTILLED_SIGMA_VALUES, device=self.device)
171
 
172
  def denoising_loop(sigmas, video_state, audio_state, stepper):
173
  return euler_denoising_loop(
@@ -182,26 +185,26 @@ class LTX23DistilledA2VPipeline(DistilledPipeline):
182
  ),
183
  )
184
 
185
- output_shape = VideoPixelShape(
186
  batch=1,
187
  frames=num_frames,
188
- width=width,
189
- height=height,
190
  fps=frame_rate,
191
  )
192
- conditionings = combined_image_conditionings(
193
  images=images,
194
- height=output_shape.height,
195
- width=output_shape.width,
196
  video_encoder=video_encoder,
197
  dtype=dtype,
198
  device=self.device,
199
  )
200
- video_state, audio_state = denoise_audio_video(
201
- output_shape=output_shape,
202
- conditionings=conditionings,
203
  noiser=noiser,
204
- sigmas=sigmas,
205
  stepper=stepper,
206
  denoising_loop_fn=denoising_loop,
207
  components=self.pipeline_components,
@@ -210,6 +213,39 @@ class LTX23DistilledA2VPipeline(DistilledPipeline):
210
  initial_audio_latent=encoded_audio_latent,
211
  )
212
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
213
  torch.cuda.synchronize()
214
  del transformer
215
  del video_encoder
@@ -221,19 +257,11 @@ class LTX23DistilledA2VPipeline(DistilledPipeline):
221
  tiling_config,
222
  generator,
223
  )
224
-
225
- # If audio was provided as input, return it as-is (higher fidelity than decoded)
226
- # If no audio input, decode the generated audio latent from the denoising
227
- if original_audio is not None:
228
- return decoded_video, original_audio
229
- else:
230
- from ltx_core.model.audio_vae import decode_audio as vae_decode_audio
231
- generated_audio = vae_decode_audio(
232
- audio_state.latent,
233
- self.model_ledger.audio_decoder(),
234
- self.model_ledger.vocoder(),
235
- )
236
- return decoded_video, generated_audio
237
 
238
 
239
  # Model repos
 
65
  from ltx_pipelines.utils.helpers import (
66
  cleanup_memory,
67
  combined_image_conditionings,
 
68
  denoise_video_only,
69
  encode_prompts,
70
  simple_denoising_func,
 
102
 
103
 
104
  class LTX23DistilledA2VPipeline(DistilledPipeline):
105
+ """DistilledPipeline with optional audio conditioning."""
106
 
107
  def __call__(
108
  self,
 
117
  tiling_config: TilingConfig | None = None,
118
  enhance_prompt: bool = False,
119
  ):
120
+ # Standard path when no audio input is provided.
121
  print(prompt)
122
+ if audio_path is None:
123
+ return super().__call__(
124
+ prompt=prompt,
125
+ seed=seed,
126
+ height=height,
127
+ width=width,
128
+ num_frames=num_frames,
129
+ frame_rate=frame_rate,
130
+ images=images,
131
+ tiling_config=tiling_config,
132
+ enhance_prompt=enhance_prompt,
133
+ )
134
 
135
  generator = torch.Generator(device=self.device).manual_seed(seed)
136
  noiser = GaussianNoiser(generator=generator)
 
145
  )
146
  video_context, audio_context = ctx_p.video_encoding, ctx_p.audio_encoding
147
 
148
+ video_duration = num_frames / frame_rate
149
+ decoded_audio = decode_audio_from_file(audio_path, self.device, 0.0, video_duration)
150
+ if decoded_audio is None:
151
+ raise ValueError(f"Could not extract audio stream from {audio_path}")
152
+
153
+ encoded_audio_latent = vae_encode_audio(decoded_audio, self.model_ledger.audio_encoder())
154
+ audio_shape = AudioLatentShape.from_duration(batch=1, duration=video_duration, channels=8, mel_bins=16)
155
+ expected_frames = audio_shape.frames
156
+ actual_frames = encoded_audio_latent.shape[2]
157
+
158
+ if actual_frames > expected_frames:
159
+ encoded_audio_latent = encoded_audio_latent[:, :, :expected_frames, :]
160
+ elif actual_frames < expected_frames:
161
+ pad = torch.zeros(
162
+ encoded_audio_latent.shape[0],
163
+ encoded_audio_latent.shape[1],
164
+ expected_frames - actual_frames,
165
+ encoded_audio_latent.shape[3],
166
+ device=encoded_audio_latent.device,
167
+ dtype=encoded_audio_latent.dtype,
 
 
 
 
 
 
 
 
 
 
168
  )
169
+ encoded_audio_latent = torch.cat([encoded_audio_latent, pad], dim=2)
170
 
171
  video_encoder = self.model_ledger.video_encoder()
172
  transformer = self.model_ledger.transformer()
173
+ stage_1_sigmas = torch.tensor(DISTILLED_SIGMA_VALUES, device=self.device)
174
 
175
  def denoising_loop(sigmas, video_state, audio_state, stepper):
176
  return euler_denoising_loop(
 
185
  ),
186
  )
187
 
188
+ stage_1_output_shape = VideoPixelShape(
189
  batch=1,
190
  frames=num_frames,
191
+ width=width // 2,
192
+ height=height // 2,
193
  fps=frame_rate,
194
  )
195
+ stage_1_conditionings = combined_image_conditionings(
196
  images=images,
197
+ height=stage_1_output_shape.height,
198
+ width=stage_1_output_shape.width,
199
  video_encoder=video_encoder,
200
  dtype=dtype,
201
  device=self.device,
202
  )
203
+ video_state = denoise_video_only(
204
+ output_shape=stage_1_output_shape,
205
+ conditionings=stage_1_conditionings,
206
  noiser=noiser,
207
+ sigmas=stage_1_sigmas,
208
  stepper=stepper,
209
  denoising_loop_fn=denoising_loop,
210
  components=self.pipeline_components,
 
213
  initial_audio_latent=encoded_audio_latent,
214
  )
215
 
216
+ torch.cuda.synchronize()
217
+ cleanup_memory()
218
+
219
+ upscaled_video_latent = upsample_video(
220
+ latent=video_state.latent[:1],
221
+ video_encoder=video_encoder,
222
+ upsampler=self.model_ledger.spatial_upsampler(),
223
+ )
224
+ stage_2_sigmas = torch.tensor(STAGE_2_DISTILLED_SIGMA_VALUES, device=self.device)
225
+ stage_2_output_shape = VideoPixelShape(batch=1, frames=num_frames, width=width, height=height, fps=frame_rate)
226
+ stage_2_conditionings = combined_image_conditionings(
227
+ images=images,
228
+ height=stage_2_output_shape.height,
229
+ width=stage_2_output_shape.width,
230
+ video_encoder=video_encoder,
231
+ dtype=dtype,
232
+ device=self.device,
233
+ )
234
+ video_state = denoise_video_only(
235
+ output_shape=stage_2_output_shape,
236
+ conditionings=stage_2_conditionings,
237
+ noiser=noiser,
238
+ sigmas=stage_2_sigmas,
239
+ stepper=stepper,
240
+ denoising_loop_fn=denoising_loop,
241
+ components=self.pipeline_components,
242
+ dtype=dtype,
243
+ device=self.device,
244
+ noise_scale=stage_2_sigmas[0],
245
+ initial_video_latent=upscaled_video_latent,
246
+ initial_audio_latent=encoded_audio_latent,
247
+ )
248
+
249
  torch.cuda.synchronize()
250
  del transformer
251
  del video_encoder
 
257
  tiling_config,
258
  generator,
259
  )
260
+ original_audio = Audio(
261
+ waveform=decoded_audio.waveform.squeeze(0),
262
+ sampling_rate=decoded_audio.sampling_rate,
263
+ )
264
+ return decoded_video, original_audio
 
 
 
 
 
 
 
 
265
 
266
 
267
  # Model repos