File size: 2,819 Bytes
d8bc908
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
"""TemporalFrameBuffer — ring buffer for video latents with HCA compression.

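Stores the last N video latents (flattened) in a local GPU ring buffer and, every
``cache_stride`` appends, projects the newest latent through a TernaryScaleTensor
(down to a quarter of its flattened size) into a long-range compressed cache.
Used to condition video generation on previous time steps.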

Latent shape: [B, C, H', W'] where C=OPEN_SORA_LATENT_CHANNELS=4,
H'=VIDEO_HEIGHT=32, W'=VIDEO_WIDTH=32. Each "latent" is one 4-frame chunk.
"""
import torch
import torch.nn as nn
from ..kernel.ternary_scale import TernaryScaleTensor, TScaleType
from .ring_buffer import GPURingBuffer
from ..config import FRAME_BUFFER_LOCAL_SIZE, FRAME_BUFFER_CACHE_STRIDE, \
    OPEN_SORA_LATENT_CHANNELS, VIDEO_HEIGHT, VIDEO_WIDTH


class TemporalFrameBuffer(nn.Module):
    """Temporal conditioning buffer for video latents.

    Keeps the most recent ``local_size`` latents, each flattened to
    ``latent_channels * height * width`` values, in a ``GPURingBuffer``.
    Every ``cache_stride`` appends, the newest latent is projected through a
    ``TernaryScaleTensor`` down to a quarter of that size and the (detached)
    result is stored in ``compressed_cache``.

    Args:
        local_size: Capacity of the local ring buffer.
        cache_stride: Number of appends between compressed-cache entries.
        latent_channels: Channel count of each latent.
        height: Spatial height of each latent.
        width: Spatial width of each latent.
        tscale_type: ``TScaleType`` forwarded to the compression projection.
        max_cache_size: Optional cap (keyword-only) on the number of compressed
            entries kept; the oldest entries are dropped once it is exceeded.
            ``None`` (the default) keeps every entry.

    Raises:
        ValueError: If ``max_cache_size`` is given and is smaller than 1.
    """

    def __init__(self, local_size=FRAME_BUFFER_LOCAL_SIZE,
                 cache_stride=FRAME_BUFFER_CACHE_STRIDE,
                 latent_channels=OPEN_SORA_LATENT_CHANNELS,
                 height=VIDEO_HEIGHT, width=VIDEO_WIDTH,
                 tscale_type=TScaleType.T32,
                 *, max_cache_size=None):
        super().__init__()
        # 0 would be a trap: ``del lst[:-0]`` deletes nothing, so it would
        # silently behave like "unbounded". Reject it explicitly.
        if max_cache_size is not None and max_cache_size < 1:
            raise ValueError(
                f"max_cache_size must be None or >= 1, got {max_cache_size}")

        self.latent_channels = latent_channels
        self.spatial_dim = height * width
        self.latent_flat_dim = latent_channels * self.spatial_dim
        # Width of each compressed-cache entry (output size of compress_proj).
        self.compressed_dim = self.latent_flat_dim // 4

        self.local = GPURingBuffer(
            max_size=local_size,
            dtype=torch.float32,
            dim=self.latent_flat_dim,
        )

        self.compress_proj = TernaryScaleTensor(
            self.latent_flat_dim,
            self.compressed_dim,
            tscale_type=tscale_type,
        )
        self.compressed_cache = []
        self.cache_stride = cache_stride
        self.max_cache_size = max_cache_size
        self._frames_since_compress = 0
        # Device of the most recently appended latent. It places the empty
        # placeholders returned before any data exists, so they can be
        # concatenated with real (possibly GPU) tensors downstream.
        self._last_device = None

    def _empty(self, feature_dim, device=None, dtype=None):
        """Return an empty ``[0, 1, feature_dim]`` placeholder tensor.

        Falls back to the device of the last appended latent when no device is
        given (and to torch's default device if nothing was ever appended).
        """
        if device is None:
            device = self._last_device
        return torch.zeros(0, 1, feature_dim, device=device, dtype=dtype)

    def append(self, latent):
        """Push one latent batch into the buffer.

        ``latent`` is flattened per sample to ``[B, latent_flat_dim]``. Every
        ``cache_stride`` calls, it is also compressed into ``compressed_cache``.

        Raises:
            ValueError: If the flattened per-sample size does not equal
                ``latent_flat_dim``.
        """
        batch = latent.shape[0]
        flat = latent.reshape(batch, -1)
        if flat.shape[1] != self.latent_flat_dim:
            raise ValueError(
                f"expected {self.latent_flat_dim} values per sample "
                f"(latent_channels*height*width), got {flat.shape[1]} "
                f"from latent of shape {tuple(latent.shape)}")
        self._last_device = flat.device
        self.local.append(flat)

        self._frames_since_compress += 1
        if self._frames_since_compress >= self.cache_stride:
            # The cached entry is detached anyway, so building an autograd
            # graph for the projection is wasted work and memory.
            with torch.no_grad():
                compressed = self.compress_proj(flat)
            self.compressed_cache.append(compressed.detach())
            if (self.max_cache_size is not None
                    and len(self.compressed_cache) > self.max_cache_size):
                # Drop the oldest entries, keeping the newest max_cache_size.
                del self.compressed_cache[:-self.max_cache_size]
            self._frames_since_compress = 0

    def get_local(self, n=None):
        """Return up to the last ``n`` latents from the local ring buffer.

        ``n=None`` means the full buffer capacity. A non-positive ``n``
        returns an empty ``[0, 1, latent_flat_dim]`` tensor, as does an empty
        buffer. A 1-D result gets a leading axis added.
        """
        if n is None:
            n = self.local.max_size
        if n <= 0:
            # Never ask the ring buffer for "last 0"; slice-style
            # implementations would return everything for that.
            return self._empty(self.latent_flat_dim, dtype=torch.float32)
        result = self.local.get_last_n(n)
        if result.dim() == 0 or result.shape[0] == 0:
            return self._empty(self.latent_flat_dim,
                               device=result.device, dtype=result.dtype)
        if result.dim() == 1:
            result = result.unsqueeze(0)
        return result

    def get_compressed_cache(self):
        """Stack compressed entries along a new leading axis.

        Returns an empty ``[0, 1, compressed_dim]`` tensor when the cache is
        empty.
        """
        if not self.compressed_cache:
            return self._empty(self.compressed_dim)
        return torch.stack(self.compressed_cache, dim=0)

    def get_conditioning(self, n_local=None):
        """Return ``{"local": ..., "compressed": ...}`` conditioning tensors."""
        return {
            "local": self.get_local(n_local),
            "compressed": self.get_compressed_cache(),
        }

    def reset(self):
        """Clear the local buffer, the compressed cache and the stride counter.

        The remembered device is kept so empty placeholders stay on it.
        """
        self.local.reset()
        self.compressed_cache = []
        self._frames_since_compress = 0