| """TemporalFrameBuffer — ring buffer for video latents with HCA compression. |
| |
| Stores the last N video latents (local) and maintains a compressed long-range |
| cache via TernaryScaleTensor projection. Used for conditioning video generation |
| on previous time steps. |
| |
| Latent shape: [B, C, H', W'] where C=OPEN_SORA_LATENT_CHANNELS=4, |
| H'=VIDEO_HEIGHT=32, W'=VIDEO_WIDTH=32. Each "latent" is one 4-frame chunk. |
| """ |
| import torch |
| import torch.nn as nn |
| from ..kernel.ternary_scale import TernaryScaleTensor, TScaleType |
| from .ring_buffer import GPURingBuffer |
| from ..config import FRAME_BUFFER_LOCAL_SIZE, FRAME_BUFFER_CACHE_STRIDE, \ |
| OPEN_SORA_LATENT_CHANNELS, VIDEO_HEIGHT, VIDEO_WIDTH |
|
|
|
|
class TemporalFrameBuffer(nn.Module):
    """Two-level memory of previously appended video latents.

    * ``local`` -- a ``GPURingBuffer`` holding the most recent ``local_size``
      latents, each flattened to ``latent_channels * height * width`` values.
    * ``compressed_cache`` -- a list holding every ``cache_stride``-th
      appended latent, projected by a ``TernaryScaleTensor`` down to
      ``latent_flat_dim // 4`` features and detached from autograd.

    Args:
        local_size: Capacity of the local ring buffer.
        cache_stride: Number of ``append`` calls between compressions.
        latent_channels: Channel count of each latent.
        height: Spatial height of each latent.
        width: Spatial width of each latent.
        tscale_type: ``TScaleType`` forwarded to the compression projection.
        max_cache_size: Optional cap on the number of compressed entries
            kept; when exceeded, the oldest entries are dropped. ``None``
            (the default) keeps every entry.

    Raises:
        ValueError: If ``max_cache_size`` is negative.
    """

    def __init__(self, local_size=FRAME_BUFFER_LOCAL_SIZE,
                 cache_stride=FRAME_BUFFER_CACHE_STRIDE,
                 latent_channels=OPEN_SORA_LATENT_CHANNELS,
                 height=VIDEO_HEIGHT, width=VIDEO_WIDTH,
                 tscale_type=TScaleType.T32,
                 max_cache_size=None):
        super().__init__()
        if max_cache_size is not None and max_cache_size < 0:
            raise ValueError(
                f"max_cache_size must be None or >= 0, got {max_cache_size}"
            )
        self.latent_channels = latent_channels
        self.spatial_dim = height * width
        self.latent_flat_dim = latent_channels * self.spatial_dim

        # Short-range memory: the most recent latents, stored flattened.
        self.local = GPURingBuffer(
            max_size=local_size,
            dtype=torch.float32,
            dim=self.latent_flat_dim,
        )

        # Long-range memory: a sparse history, each entry projected to a
        # quarter of the flat latent size.
        self.compress_proj = TernaryScaleTensor(
            self.latent_flat_dim,
            self.latent_flat_dim // 4,
            tscale_type=tscale_type,
        )
        self.compressed_cache = []
        self.cache_stride = cache_stride
        self.max_cache_size = max_cache_size
        self._frames_since_compress = 0

    def append(self, latent):
        """Record one latent.

        The latent always goes into the local ring buffer. Every
        ``cache_stride`` calls it is also projected and appended to the
        compressed cache.

        Args:
            latent: Tensor whose leading dimension is the batch size ``B``.
                It is reshaped to ``[B, -1]`` before storage; presumably the
                trailing size must equal ``latent_flat_dim`` (TODO confirm
                against ``GPURingBuffer``).
        """
        batch = latent.shape[0]
        flat = latent.reshape(batch, -1)
        self.local.append(flat)

        self._frames_since_compress += 1
        if self._frames_since_compress >= self.cache_stride:
            # The cached entry is detached anyway, so avoid building (and
            # holding memory for) an autograd graph through the projection.
            with torch.no_grad():
                compressed = self.compress_proj(flat)
            self.compressed_cache.append(compressed.detach())
            self._frames_since_compress = 0
            self._trim_compressed_cache()

    def _trim_compressed_cache(self):
        """Drop the oldest compressed entries beyond ``max_cache_size``."""
        # getattr keeps instances pickled before this attribute existed working.
        limit = getattr(self, "max_cache_size", None)
        if limit is None:
            return
        excess = len(self.compressed_cache) - limit
        if excess > 0:
            # Entries are appended in time order, so the front is the oldest.
            del self.compressed_cache[:excess]

    def get_local(self, n=None):
        """Return up to the ``n`` most recent latents from the local buffer.

        Args:
            n: How many latents to fetch. ``None`` means the buffer's full
                capacity; ``0`` yields an empty result.

        Returns:
            The tensor from ``GPURingBuffer.get_last_n``, promoted to at
            least 2-D. When nothing is available, an empty tensor of shape
            ``[0, 1, latent_flat_dim]`` on torch's default device.
        """
        if n is None:
            n = self.local.max_size
        if n == 0:
            return torch.zeros(0, 1, self.latent_flat_dim)
        result = self.local.get_last_n(n)
        if result.dim() == 0 or result.shape[0] == 0:
            return torch.zeros(0, 1, self.latent_flat_dim)
        if result.dim() == 1:
            # A single flat vector: give it a leading axis.
            result = result.unsqueeze(0)
        return result

    def get_compressed_cache(self):
        """Return the compressed history stacked along a new leading axis.

        Returns:
            ``torch.stack`` of the cached entries, oldest first. When the
            cache is empty, an empty tensor of shape
            ``[0, 1, latent_flat_dim // 4]`` on torch's default device.
        """
        if not self.compressed_cache:
            return torch.zeros(0, 1, self.latent_flat_dim // 4)
        return torch.stack(self.compressed_cache, dim=0)

    def get_conditioning(self, n_local=None):
        """Bundle both memories for use as generation conditioning.

        Args:
            n_local: Forwarded to ``get_local`` as ``n``.

        Returns:
            Dict with ``"local"`` (from ``get_local``) and ``"compressed"``
            (from ``get_compressed_cache``).
        """
        return {
            "local": self.get_local(n_local),
            "compressed": self.get_compressed_cache(),
        }

    def reset(self):
        """Clear both memories and restart the compression stride counter."""
        self.local.reset()
        self.compressed_cache = []
        self._frames_since_compress = 0
|
|