# ARBS/arbitor/attention/frame_buffer.py
# Uploaded by CLIWorks using huggingface_hub (commit d8bc908, verified).
"""TemporalFrameBuffer — ring buffer for video latents with HCA compression.
Stores the last N video latents (local) and maintains a compressed long-range
cache via TernaryScaleTensor projection. Used for conditioning video generation
on previous time steps.
Latent shape: [B, C, H', W'] where C=OPEN_SORA_LATENT_CHANNELS=4,
H'=VIDEO_HEIGHT=32, W'=VIDEO_WIDTH=32. Each "latent" is one 4-frame chunk.
"""
import torch
import torch.nn as nn
from ..kernel.ternary_scale import TernaryScaleTensor, TScaleType
from .ring_buffer import GPURingBuffer
from ..config import FRAME_BUFFER_LOCAL_SIZE, FRAME_BUFFER_CACHE_STRIDE, \
OPEN_SORA_LATENT_CHANNELS, VIDEO_HEIGHT, VIDEO_WIDTH
class TemporalFrameBuffer(nn.Module):
    """Window of recent video latents plus a strided, compressed history.

    Every appended latent is flattened to ``[B, latent_flat_dim]`` and pushed
    into a fixed-capacity ``GPURingBuffer`` (the "local" window). Every
    ``cache_stride``-th append additionally runs that flattened latent through
    ``compress_proj`` and stores the (detached) result in ``compressed_cache``.

    Args:
        local_size: capacity passed to the local ``GPURingBuffer``.
        cache_stride: number of appends between two compressed snapshots.
        latent_channels, height, width: latent geometry; their product is the
            flattened size every appended latent must have.
        tscale_type: forwarded to ``TernaryScaleTensor``.
        max_compressed: optional cap on the number of compressed snapshots
            retained; the oldest are dropped first. ``None`` (default) keeps
            every snapshot.

    Raises:
        ValueError: if ``max_compressed`` is negative.
    """

    def __init__(self, local_size=FRAME_BUFFER_LOCAL_SIZE,
                 cache_stride=FRAME_BUFFER_CACHE_STRIDE,
                 latent_channels=OPEN_SORA_LATENT_CHANNELS,
                 height=VIDEO_HEIGHT, width=VIDEO_WIDTH,
                 tscale_type=TScaleType.T32,
                 max_compressed=None):
        super().__init__()
        if max_compressed is not None and max_compressed < 0:
            raise ValueError(
                f"max_compressed must be None or >= 0, got {max_compressed}"
            )
        self.latent_channels = latent_channels
        self.spatial_dim = height * width
        self.latent_flat_dim = latent_channels * self.spatial_dim
        self.local = GPURingBuffer(
            max_size=local_size,
            dtype=torch.float32,
            dim=self.latent_flat_dim,
        )
        # Maps a flattened latent to a quarter of its width (presumably
        # in_features -> out_features; confirm TernaryScaleTensor's argument
        # order).
        self.compress_proj = TernaryScaleTensor(
            self.latent_flat_dim,
            self.latent_flat_dim // 4,
            tscale_type=tscale_type,
        )
        self.compressed_cache = []
        self.cache_stride = cache_stride
        self.max_compressed = max_compressed
        self._frames_since_compress = 0
        # Device of the most recently appended latent, so the empty
        # placeholders returned by the getters match the device of real data.
        # None (default device) until the first append.
        self._device = None

    def _empty(self, width):
        """Return a ``[0, 1, width]`` placeholder on the last-seen device."""
        return torch.zeros(0, 1, width, device=self._device)

    def _trim_compressed_cache(self):
        """Drop the oldest compressed snapshots beyond ``max_compressed``."""
        if self.max_compressed is None:
            return
        overflow = len(self.compressed_cache) - self.max_compressed
        if overflow > 0:
            del self.compressed_cache[:overflow]

    def append(self, latent):
        """Push one latent into the local window and maybe snapshot it.

        Args:
            latent: tensor whose leading dim is the batch size and whose
                remaining dims flatten to ``latent_flat_dim``
                (e.g. ``[B, C, H', W']``).

        Raises:
            ValueError: if the latent does not flatten to ``latent_flat_dim``.
        """
        batch = latent.shape[0]
        flat = latent.reshape(batch, -1)
        if flat.shape[1] != self.latent_flat_dim:
            raise ValueError(
                f"latent flattens to {flat.shape[1]} features per item, "
                f"expected {self.latent_flat_dim}"
            )
        self._device = latent.device
        self.local.append(flat)
        self._frames_since_compress += 1
        if self._frames_since_compress >= self.cache_stride:
            # The snapshot is detached, so no gradient ever flows back
            # through it; no_grad just avoids building an unused graph.
            with torch.no_grad():
                compressed = self.compress_proj(flat)
            self.compressed_cache.append(compressed.detach())
            self._frames_since_compress = 0
            self._trim_compressed_cache()

    def get_local(self, n=None):
        """Return the last ``n`` entries of the local window.

        Args:
            n: number of entries to fetch; ``None`` means the ring buffer's
                ``max_size``. ``n <= 0`` yields the empty placeholder.

        Returns:
            Whatever ``GPURingBuffer.get_last_n`` returns, promoted to 2-D if
            it is 1-D, or a ``[0, 1, latent_flat_dim]`` placeholder when
            nothing is available.
        """
        if n is None:
            n = self.local.max_size
        if n <= 0:
            return self._empty(self.latent_flat_dim)
        result = self.local.get_last_n(n)
        if result.dim() == 0 or result.shape[0] == 0:
            return self._empty(self.latent_flat_dim)
        if result.dim() == 1:
            result = result.unsqueeze(0)
        return result

    def get_compressed_cache(self):
        """Return all compressed snapshots stacked along a new leading dim.

        Returns:
            ``torch.stack`` of the snapshots (presumably
            ``[N, B, latent_flat_dim // 4]``), or a
            ``[0, 1, latent_flat_dim // 4]`` placeholder when empty.
        """
        if not self.compressed_cache:
            return self._empty(self.latent_flat_dim // 4)
        return torch.stack(self.compressed_cache, dim=0)

    def get_conditioning(self, n_local=None):
        """Bundle the local window and the compressed history.

        Args:
            n_local: forwarded to ``get_local``.

        Returns:
            dict with keys ``"local"`` and ``"compressed"``.
        """
        return {
            "local": self.get_local(n_local),
            "compressed": self.get_compressed_cache(),
        }

    def reset(self):
        """Clear the local window, the compressed history and the stride counter."""
        self.local.reset()
        self.compressed_cache = []
        self._frames_since_compress = 0