File size: 12,690 Bytes
bad41bb
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e3ea692
9743ceb
bad41bb
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7a0b828
bad41bb
 
ecf81c1
 
bad41bb
 
 
 
 
 
 
 
 
 
 
bfe2ce5
 
9743ceb
bfe2ce5
9743ceb
bad41bb
9743ceb
bad41bb
 
 
 
 
 
 
 
bfe2ce5
bad41bb
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9743ceb
bad41bb
9743ceb
bad41bb
9743ceb
bad41bb
9743ceb
bad41bb
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a608b1a
 
bad41bb
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8505af9
bad41bb
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
48c9e68
 
bad41bb
 
 
 
 
 
 
a098f9f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
bad41bb
a608b1a
 
bad41bb
 
 
 
a608b1a
 
bad41bb
 
 
 
a608b1a
 
bad41bb
 
 
a608b1a
 
 
 
 
 
 
 
 
8a9219c
a608b1a
 
 
 
 
5446da2
a608b1a
 
bad41bb
 
 
 
 
 
 
1fbebf9
9ac3e91
bad41bb
48c9e68
bad41bb
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b190ba6
 
bad41bb
e3ea692
bad41bb
 
 
 
 
e3ea692
bad41bb
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
"""
VOID – Video Object and Interaction Deletion
Gradio demo for Hugging Face Spaces (ZeroGPU)
"""

import os
import sys
import tempfile

import numpy as np
import torch
import torch.nn.functional as F
import imageio
import mediapy as media
import spaces
import gradio as gr
from huggingface_hub import hf_hub_download, snapshot_download
from safetensors.torch import load_file
from diffusers import DDIMScheduler
from PIL import Image

# ── project imports ────────────────────────────────────────────────────────────
sys.path.insert(0, os.path.dirname(__file__))

from videox_fun.models import (
    AutoencoderKLCogVideoX,
    CogVideoXTransformer3DModel,
    T5EncoderModel,
    T5Tokenizer,
)
from videox_fun.pipeline import CogVideoXFunInpaintPipeline
from videox_fun.utils.fp8_optimization import convert_weight_dtype_wrapper
from videox_fun.utils.utils import temporal_padding

# ── constants ──────────────────────────────────────────────────────────────────
# Set these env vars in your HF Space settings, or hardcode once weights are public.
BASE_MODEL_ID  = os.environ.get("BASE_MODEL_ID", "alibaba-pai/CogVideoX-Fun-V1.5-5b-InP")
VOID_MODEL_ID  = os.environ.get("VOID_MODEL_ID", "your-hf-username/VOID")
VOID_CKPT_FILE = "void_pass1.safetensors"  # Pass-1 checkpoint file inside VOID_MODEL_ID

SAMPLE_SIZE  = (384, 672)   # H × W — every input video / mask is resized to this
MAX_VID_LEN  = 197          # hard cap on frames read from an uploaded video
TEMPORAL_WIN = 72           # frames handed to the pipeline per run (num_frames)
FPS          = 12           # frame rate of the output mp4
WEIGHT_DTYPE = torch.bfloat16  # compute dtype for VAE / text encoder / transformer
NEG_PROMPT = (
    "The video is not of a high quality, it has a low resolution. "
    "Watermark present in each frame. The background is solid. "
    "Strange body and strange trajectory. Distortion."
)

# ── model loading (once at startup, lives in CPU RAM between GPU requests) ─────
print("Loading VOID pipeline …")

# Token for private repos; None is fine once the weights are public.
HF_TOKEN = os.environ.get("HF_TOKEN")

# Download base model to local cache (custom from_pretrained needs a local path)
base_model_path = snapshot_download(repo_id=BASE_MODEL_ID, token=HF_TOKEN)

# Transformer weights are loaded as fp8, then the module is cast to bfloat16.
transformer = CogVideoXTransformer3DModel.from_pretrained(
    base_model_path,
    subfolder="transformer",
    low_cpu_mem_usage=True,
    torch_dtype=torch.float8_e4m3fn,  # qfloat8 to save VRAM
    use_vae_mask=True,
    stack_mask=False,
).to(WEIGHT_DTYPE)

# Load VOID Pass-1 checkpoint
ckpt_path  = hf_hub_download(repo_id=VOID_MODEL_ID, filename=VOID_CKPT_FILE, token=HF_TOKEN)
state_dict = load_file(ckpt_path)
# Some checkpoints nest the weights under a "state_dict" key; unwrap if present.
state_dict = state_dict.get("state_dict", state_dict)

# Adapt patch_embed channels if they differ (mask-conditioning channels added)
# Copies the first and last `feat_dim` input channels from the checkpoint into a
# clone of the base weight; any channels in between keep the base-model values.
param_name = "patch_embed.proj.weight"
if state_dict[param_name].size(1) != transformer.state_dict()[param_name].size(1):
    feat_dim   = 16 * 8  # latent_channels * feat_scale
    new_weight = transformer.state_dict()[param_name].clone()
    new_weight[:, :feat_dim]  = state_dict[param_name][:, :feat_dim]
    new_weight[:, -feat_dim:] = state_dict[param_name][:, -feat_dim:]
    state_dict[param_name] = new_weight

# strict=False: tolerate keys missing from / extra in the fine-tuned checkpoint.
transformer.load_state_dict(state_dict, strict=False)

vae = AutoencoderKLCogVideoX.from_pretrained(
    base_model_path, subfolder="vae"
).to(WEIGHT_DTYPE)
tokenizer    = T5Tokenizer.from_pretrained(base_model_path, subfolder="tokenizer")
text_encoder = T5EncoderModel.from_pretrained(
    base_model_path, subfolder="text_encoder", torch_dtype=WEIGHT_DTYPE
)
scheduler = DDIMScheduler.from_pretrained(base_model_path, subfolder="scheduler")

# Assemble the inpainting pipeline from the individually loaded components.
pipeline = CogVideoXFunInpaintPipeline(
    vae=vae,
    tokenizer=tokenizer,
    text_encoder=text_encoder,
    transformer=transformer,
    scheduler=scheduler,
)
# NOTE(review): presumably wraps the fp8-stored weights so they are upcast to
# WEIGHT_DTYPE at use time — confirm against videox_fun.utils.fp8_optimization.
convert_weight_dtype_wrapper(transformer, WEIGHT_DTYPE)
# Keep modules on CPU between requests; diffusers moves them to GPU on demand.
pipeline.enable_model_cpu_offload()

print("VOID pipeline ready.")


# ── helpers ────────────────────────────────────────────────────────────────────
def load_video_tensor(path: str) -> torch.Tensor:
    """Read a video file as a (1, C, T, H, W) float32 tensor in [0, 1].

    Frames beyond MAX_VID_LEN are dropped and every frame is resized to
    SAMPLE_SIZE using area interpolation.
    """
    raw = media.read_video(path)
    clip = torch.from_numpy(np.array(raw))[:MAX_VID_LEN]   # (T, H, W, C) uint8
    clip = clip.permute(3, 0, 1, 2).float()                # (C, T, H, W)
    clip = clip / 255.0
    resized = F.interpolate(clip, SAMPLE_SIZE, mode="area")
    return resized.unsqueeze(0)                            # add batch dim


def load_quadmask_tensor(path: str) -> torch.Tensor:
    """
    Read a quadmask video as a (1, 1, T, H, W) float32 tensor in [0, 1].

    Quadmask pixel values:
      0   -> primary object (to erase)
      63  -> overlap / interaction zone
      127 -> affected region (shadows, reflections, ...)
      255 -> background (keep)

    After quantisation the mask is inverted so 255 = "erase", 0 = "keep",
    matching the pipeline's internal convention.
    """
    raw = media.read_video(path)[:MAX_VID_LEN]
    if raw.ndim == 4:
        raw = raw[..., 0]   # color input: keep the first channel as grayscale
    mask = torch.from_numpy(np.array(raw)).unsqueeze(0).float()    # (1, T, H, W)
    mask = F.interpolate(mask, SAMPLE_SIZE, mode="area").unsqueeze(0)

    # Snap interpolated values back onto the four canonical levels using
    # cumulative thresholds: <=31 -> 0, <=95 -> 63, <=191 -> 127, else 255.
    quant = torch.zeros_like(mask)
    quant = torch.where(mask > 31,  torch.full_like(mask, 63.0),  quant)
    quant = torch.where(mask > 95,  torch.full_like(mask, 127.0), quant)
    quant = torch.where(mask > 191, torch.full_like(mask, 255.0), quant)

    # Invert so that 1.0 marks regions to erase, 0.0 regions to keep.
    return (255.0 - quant) / 255.0


def tensor_to_mp4(video: torch.Tensor) -> str:
    """Save a (1, C, T, H, W) tensor in [0, 1] to a temp mp4 and return its path.

    The file is not auto-deleted; Gradio serves it from the returned path.
    """
    frames = video[0].permute(1, 2, 3, 0).cpu().float().numpy()  # (T, H, W, C)
    frames = (frames * 255).clip(0, 255).astype(np.uint8)
    # mkstemp + close: the original NamedTemporaryFile(delete=False) left its
    # file object open, leaking one file descriptor per request.
    fd, out_path = tempfile.mkstemp(suffix=".mp4")
    os.close(fd)
    imageio.mimsave(out_path, frames, fps=FPS)
    return out_path


# ── inference ──────────────────────────────────────────────────────────────────
@spaces.GPU(duration=120)
def run_inpaint(
    input_video_path: str,
    mask_video_path: str,
    prompt: str,
    num_steps: int,
    guidance_scale: float,
    seed: int,
) -> str:
    """Erase the quadmasked object (and its interactions) from the input video.

    Args:
        input_video_path: path to the uploaded RGB video.
        mask_video_path:  path to the matching quadmask video.
        prompt:           description of the scene *after* removal.
        num_steps:        number of DDIM inference steps.
        guidance_scale:   classifier-free guidance scale.
        seed:             RNG seed for reproducible sampling.

    Returns:
        Path to the generated mp4.

    Raises:
        gr.Error: when a required input is missing or the prompt is empty.
    """
    if not input_video_path or not mask_video_path:
        raise gr.Error("Please upload both an input video and a quadmask video.")
    # Gradio delivers None for an untouched Textbox; check before .strip() so
    # the user sees a clean error instead of an AttributeError traceback.
    if not prompt or not prompt.strip():
        raise gr.Error("Please enter a prompt describing the scene after removal.")

    generator = torch.Generator(device="cuda").manual_seed(int(seed))

    input_video = load_video_tensor(input_video_path)
    input_mask  = load_quadmask_tensor(mask_video_path)

    # Bring both tensors to the pipeline's temporal window (pads short clips
    # up to TEMPORAL_WIN, caps at MAX_VID_LEN — see videox_fun.utils.utils).
    input_video = temporal_padding(input_video, min_length=TEMPORAL_WIN, max_length=MAX_VID_LEN)
    input_mask  = temporal_padding(input_mask,  min_length=TEMPORAL_WIN, max_length=MAX_VID_LEN)

    with torch.no_grad():
        result = pipeline(
            prompt=prompt,
            negative_prompt=NEG_PROMPT,
            height=SAMPLE_SIZE[0],
            width=SAMPLE_SIZE[1],
            num_frames=TEMPORAL_WIN,
            video=input_video,
            mask_video=input_mask,
            generator=generator,
            guidance_scale=guidance_scale,
            num_inference_steps=num_steps,
            strength=1.0,
            use_trimask=True,
            use_vae_mask=True,
            stack_mask=False,
            zero_out_mask_region=False,
        ).videos

    return tensor_to_mp4(result)


# ── Gradio UI ──────────────────────────────────────────────────────────────────
# (em dashes below were mojibake "β€”" in the original; fixed to "—")
QUADMASK_EXPLAINER = """
### Quadmask format

The quadmask is a **grayscale video** where each pixel value encodes what role that region plays:

| Pixel value | Meaning |
|-------------|---------|
| **0** (black) | Primary object to remove |
| **63** (dark grey) | Overlap of primary object / affected zone |
| **127** (mid grey) | Affected region — shadows, reflections, new and old trajectories |
| **255** (white) | Background — keep as-is |

Use the **VLM-Mask-Reasoner** pipeline included in the repo to generate quadmasks automatically.
"""

SAMPLE_DIR = os.path.join(os.path.dirname(__file__), "sample")

# (sample folder, prompt) pairs; every example shares the same default
# settings (30 steps, guidance 1.0, seed 42) and file naming convention.
_EXAMPLE_SPECS = [
    ("BigBen",        "A video of London's skyline reflecting in the Thames"),
    ("trampoline",    "A video of an empty trampoline."),
    ("spinner",       "A video of two spinningtops spinning."),
    ("ducky-float",   "A video of a rubber ducky."),
    ("lime",          "A lime falls on the table."),
    ("moving_ball",   "A ball rolls off the table."),
    ("pillow",        "Two pillows placed on the table."),
    ("bowling",       "Bowling pins standing on the grass."),
    ("crush-can",     "A soda can on the table."),
    ("toast-shmello", "A marshmallow dessert."),
]

# Rows in the exact order gr.Examples feeds the inputs:
# [input video, quadmask video, prompt, steps, guidance, seed]
EXAMPLES = [
    [
        os.path.join(SAMPLE_DIR, folder, "input_video.mp4"),
        os.path.join(SAMPLE_DIR, folder, "quadmask_0.mp4"),
        prompt,
        30, 1.0, 42,
    ]
    for folder, prompt in _EXAMPLE_SPECS
]

# Layout is declarative and order-dependent; structure kept identical to the
# original. User-facing strings with mojibake ("πŸ’»", "β€”") fixed to the
# intended "💻" and "—".
with gr.Blocks(title="VOID – Video Object & Interaction Deletion") as demo:
    gr.Markdown(
        """
# VOID – Video Object and Interaction Deletion

[🌐 Project Page](https://void-model.github.io/) | [💻 GitHub](https://github.com/Netflix/void-model)

Upload a video and its **quadmask**, enter a prompt describing the scene *after* removal,
and VOID will erase the object along with its physical interactions.

> Built on **CogVideoX-Fun-V1.5-5B** fine-tuned for interaction-aware video inpainting.
        """
    )

    with gr.Row():
        # Left column: all inputs.
        with gr.Column():
            input_video = gr.Video(label="Input video", sources=["upload"])
            mask_video  = gr.Video(label="Quadmask video", sources=["upload"])
            prompt = gr.Textbox(
                label="Prompt — describe the scene after removal",
                placeholder="e.g. A wooden table with nothing on it.",
                lines=2,
            )
            with gr.Accordion("Advanced settings", open=False):
                num_steps      = gr.Slider(10, 50, value=30, step=1,    label="Inference steps")
                guidance_scale = gr.Slider(1.0, 10.0, value=1.0, step=0.5, label="Guidance scale")
                seed           = gr.Number(value=42, label="Seed", precision=0)
            run_btn = gr.Button("Run VOID", variant="primary")

        # Right column: result.
        with gr.Column():
            output_video = gr.Video(label="Inpainted output", interactive=False)

    gr.Markdown(QUADMASK_EXPLAINER)

    gr.Examples(
        examples=EXAMPLES,
        inputs=[input_video, mask_video, prompt, num_steps, guidance_scale, seed],
        outputs=[output_video],
        cache_examples=False,  # outputs are large videos; don't pre-render them
        label="Sample sequences — click to load inputs",
    )

    run_btn.click(
        fn=run_inpaint,
        inputs=[input_video, mask_video, prompt, num_steps, guidance_scale, seed],
        outputs=[output_video],
    )

if __name__ == "__main__":
    demo.launch()