Manmay Nakhashi committed on
Commit
8018d88
·
1 Parent(s): 96ef84a

Drop video_connector + video_aggregate_embed BEFORE .to(device)

Browse files

Required for the trimmed audio-components.safetensors (video parts
stripped on disk). Without this patch, the missing tensors stay on
the meta device and .to('cuda') fails with 'Cannot copy out of meta
tensor'. The patch was applied locally for the file-trim work but never
landed on the Space; this push restores the Space build.

Files changed (1) hide show
  1. ltx2/ltx_pipelines/utils/blocks.py +23 -7
ltx2/ltx_pipelines/utils/blocks.py CHANGED
@@ -363,26 +363,36 @@ class PromptEncoder:
363
  self._warm_text_encoder = self._text_encoder_builder.build(
364
  device=self._device, dtype=self._dtype
365
  ).eval()
366
- self._warm_embeddings_processor = self._embeddings_processor_builder.build(
367
  device=self._device, dtype=self._dtype
368
- ).to(self._device).eval()
369
 
370
- # Audio-only mode: delete video components to free ~4.8GB VRAM
371
- if audio_only and self._warm_embeddings_processor is not None:
 
 
 
 
372
  import logging as _log
373
- ep = self._warm_embeddings_processor
374
  freed = 0
375
 
376
  # 1. Replace video_connector with None and patch create_embeddings
377
  if ep.video_connector is not None:
378
- freed += sum(p.numel() * p.element_size() for p in ep.video_connector.parameters())
 
 
 
379
  del ep.video_connector
380
  ep.video_connector = None
381
 
382
  # 2. Replace video_aggregate_embed with a dummy that returns zeros
383
  fe = ep.feature_extractor
384
  if hasattr(fe, 'video_aggregate_embed') and fe.video_aggregate_embed is not None:
385
- freed += sum(p.numel() * p.element_size() for p in fe.video_aggregate_embed.parameters())
 
 
 
386
  out_features = fe.video_aggregate_embed.out_features
387
  del fe.video_aggregate_embed
388
  # Dummy that returns zeros with correct shape
@@ -395,6 +405,12 @@ class PromptEncoder:
395
  device=x.device, dtype=x.dtype)
396
  fe.video_aggregate_embed = _DummyVideoEmbed(out_features)
397
 
 
 
 
 
 
 
398
  # 3. Patch create_embeddings to skip video connector
399
  _orig_create = ep.create_embeddings
400
  def _audio_only_create(video_features, audio_features, additive_attention_mask,
 
363
  self._warm_text_encoder = self._text_encoder_builder.build(
364
  device=self._device, dtype=self._dtype
365
  ).eval()
366
+ built_ep = self._embeddings_processor_builder.build(
367
  device=self._device, dtype=self._dtype
368
+ )
369
 
370
+ # Audio-only mode: delete video components BEFORE .to(device).
371
+ # This both frees ~4.8GB VRAM at load time and lets us strip
372
+ # text_embedding_projection.video_aggregate_embed.* from the
373
+ # checkpoint on disk (otherwise those tensors stay on the meta
374
+ # device and .to(device) errors with "cannot copy out of meta").
375
+ if audio_only:
376
  import logging as _log
377
+ ep = built_ep
378
  freed = 0
379
 
380
  # 1. Replace video_connector with None and patch create_embeddings
381
  if ep.video_connector is not None:
382
+ try:
383
+ freed += sum(p.numel() * p.element_size() for p in ep.video_connector.parameters() if not p.is_meta)
384
+ except Exception:
385
+ pass
386
  del ep.video_connector
387
  ep.video_connector = None
388
 
389
  # 2. Replace video_aggregate_embed with a dummy that returns zeros
390
  fe = ep.feature_extractor
391
  if hasattr(fe, 'video_aggregate_embed') and fe.video_aggregate_embed is not None:
392
+ try:
393
+ freed += sum(p.numel() * p.element_size() for p in fe.video_aggregate_embed.parameters() if not p.is_meta)
394
+ except Exception:
395
+ pass
396
  out_features = fe.video_aggregate_embed.out_features
397
  del fe.video_aggregate_embed
398
  # Dummy that returns zeros with correct shape
 
405
  device=x.device, dtype=x.dtype)
406
  fe.video_aggregate_embed = _DummyVideoEmbed(out_features)
407
 
408
+ # Now move the (post-strip) module onto the target device.
409
+ self._warm_embeddings_processor = built_ep.to(self._device).eval()
410
+
411
+ if audio_only and self._warm_embeddings_processor is not None:
412
+ ep = self._warm_embeddings_processor
413
+
414
  # 3. Patch create_embeddings to skip video connector
415
  _orig_create = ep.create_embeddings
416
  def _audio_only_create(video_features, audio_features, additive_attention_mask,