"""TemporalFrameBuffer — ring buffer for video latents with HCA compression. Stores the last N video latents (local) and maintains a compressed long-range cache via TernaryScaleTensor projection. Used for conditioning video generation on previous time steps. Latent shape: [B, C, H', W'] where C=OPEN_SORA_LATENT_CHANNELS=4, H'=VIDEO_HEIGHT=32, W'=VIDEO_WIDTH=32. Each "latent" is one 4-frame chunk. """ import torch import torch.nn as nn from ..kernel.ternary_scale import TernaryScaleTensor, TScaleType from .ring_buffer import GPURingBuffer from ..config import FRAME_BUFFER_LOCAL_SIZE, FRAME_BUFFER_CACHE_STRIDE, \ OPEN_SORA_LATENT_CHANNELS, VIDEO_HEIGHT, VIDEO_WIDTH class TemporalFrameBuffer(nn.Module): def __init__(self, local_size=FRAME_BUFFER_LOCAL_SIZE, cache_stride=FRAME_BUFFER_CACHE_STRIDE, latent_channels=OPEN_SORA_LATENT_CHANNELS, height=VIDEO_HEIGHT, width=VIDEO_WIDTH, tscale_type=TScaleType.T32): super().__init__() self.latent_channels = latent_channels self.spatial_dim = height * width self.latent_flat_dim = latent_channels * self.spatial_dim self.local = GPURingBuffer( max_size=local_size, dtype=torch.float32, dim=self.latent_flat_dim, ) self.compress_proj = TernaryScaleTensor( self.latent_flat_dim, self.latent_flat_dim // 4, tscale_type=tscale_type, ) self.compressed_cache = [] self.cache_stride = cache_stride self._frames_since_compress = 0 def append(self, latent): B = latent.shape[0] flat = latent.reshape(B, -1) self.local.append(flat) self._frames_since_compress += 1 if self._frames_since_compress >= self.cache_stride: compressed = self.compress_proj(flat) self.compressed_cache.append(compressed.detach()) self._frames_since_compress = 0 def get_local(self, n=None): n = n or self.local.max_size result = self.local.get_last_n(n) if result.dim() == 0 or result.shape[0] == 0: return torch.zeros(0, 1, self.latent_flat_dim) if result.dim() == 1: result = result.unsqueeze(0) return result def get_compressed_cache(self): if not self.compressed_cache: return torch.zeros(0, 1, self.latent_flat_dim // 4) return torch.stack(self.compressed_cache, dim=0) def get_conditioning(self, n_local=None): return { "local": self.get_local(n_local), "compressed": self.get_compressed_cache(), } def reset(self): self.local.reset() self.compressed_cache = [] self._frames_since_compress = 0