Manmay Nakhashi commited on
Commit
c29ae29
·
1 Parent(s): 8018d88

Patch end-of-clip silence prior at latent frame 513 for >20s outputs

Browse files

Base LTX-2.3 22B was trained on audio ≤~20s and learned a strong
'clip-end silence' prior that lands at the next patchifier-aligned
latent boundary (frame 513 = 8*64+1). For longer generations this
leaks through as a ~30ms silence dip near 20.4s. Linear interpolation
of frames 512-513 between neighbours 511 and 514 removes the dip
cleanly. Only triggers when latent has >513 frames so shorter outputs
remain byte-identical.

Files changed (1) hide show
  1. src/inference_server.py +21 -1
src/inference_server.py CHANGED
@@ -259,8 +259,28 @@ class TTSServer:
259
  audio_state = audio_tools.clear_conditioning(audio_state)
260
  audio_state = audio_tools.unpatchify(audio_state)
261
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
262
  t0 = time.time()
263
- decoded = self._audio_decoder(audio_state.latent)
264
  logging.info(f"Decode: {time.time()-t0:.2f}s")
265
 
266
  total = time.time() - t_total
 
259
  audio_state = audio_tools.clear_conditioning(audio_state)
260
  audio_state = audio_tools.unpatchify(audio_state)
261
 
262
+ # End-of-clip silence-prior fix.
263
+ # The base LTX-2.3 22B DiT was trained on audio clips ≤ ~20 s and
264
+ # learned a strong "clip-end silence" prior that lands on the next
265
+ # patchifier-aligned latent frame after 20 s — index 513 = 8*64+1.
266
+ # When inference produces longer audio, this prior leaks through as a
267
+ # high-norm latent burst at frame 513 (and adjacent 512), which the
268
+ # audio VAE + vocoder render as a ~30 ms hard silence dip near 20.4 s.
269
+ # Linear interpolation across the two affected frames removes the dip
270
+ # cleanly without any retraining. Only runs when the latent is long
271
+ # enough to actually contain the boundary.
272
+ latent = audio_state.latent
273
+ if latent.shape[2] > 513:
274
+ f0, f1 = 511, 514 # neighbours used for interpolation
275
+ n = f1 - f0 # = 3
276
+ patched = latent.clone()
277
+ for f in (512, 513):
278
+ t = (f - f0) / n
279
+ patched[:, :, f, :] = (1.0 - t) * latent[:, :, f0, :] + t * latent[:, :, f1, :]
280
+ latent = patched
281
+
282
  t0 = time.time()
283
+ decoded = self._audio_decoder(latent)
284
  logging.info(f"Decode: {time.time()-t0:.2f}s")
285
 
286
  total = time.time() - t_total