File size: 2,184 Bytes
46c40c4
792ff16
 
 
 
 
 
 
 
 
 
 
46c40c4
792ff16
 
 
 
 
 
46c40c4
792ff16
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
46c40c4
 
 
 
 
792ff16
46c40c4
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
"""Inference for AR curriculum model + TTA."""
import json
import numpy as np
import torch
import sys
sys.path.insert(0, "/home/coder/code")
from flow_warp_attn_model import FlowWarpAttnUNet


def load_model(model_dir: str):
    """Load a FlowWarpAttnUNet checkpoint and prepare it for inference.

    Args:
        model_dir: Directory containing ``config.json`` (must provide the keys
            ``in_channels``, ``channels`` and ``context_len``) and the
            ``model.pt`` state dict.

    Returns:
        A dict with keys ``"model"`` (the module in eval mode, moved to the
        selected device), ``"device"`` (CUDA if available, otherwise CPU) and
        ``"context_len"`` (copied from the config).
    """
    with open(f"{model_dir}/config.json", encoding="utf-8") as f:
        config = json.load(f)
    model = FlowWarpAttnUNet(in_channels=config["in_channels"], channels=config["channels"])
    sd = torch.load(f"{model_dir}/model.pt", map_location="cpu", weights_only=True)
    # Upcast floating-point tensors to float32 (presumably the checkpoint may be
    # stored in reduced precision). Integer/bool buffers keep their dtype so
    # their exact values survive the load.
    sd = {k: v.float() if v.is_floating_point() else v for k, v in sd.items()}
    model.load_state_dict(sd)
    model.eval()
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = model.to(device)
    return {"model": model, "device": device, "context_len": config["context_len"]}


def _prepare_input(context_frames, context_len):
    N = len(context_frames)
    if N >= context_len:
        frames = context_frames[-context_len:]
    else:
        pad = np.repeat(context_frames[:1], context_len - N, axis=0)
        frames = np.concatenate([pad, context_frames], axis=0)
    frames_f = frames.astype(np.float32) / 255.0
    frames_f = np.transpose(frames_f, (0, 3, 1, 2))
    context = frames_f.reshape(1, -1, 64, 64)
    last_frame = frames_f[-1:]
    return context, last_frame


def predict_next_frame(model_dict, context_frames: np.ndarray) -> np.ndarray:
    """Predict the next frame using horizontal-flip test-time augmentation.

    The model runs twice: once on the context as given, and once on the
    context mirrored left-right. The second prediction is mirrored back, and
    the two predictions are averaged.

    Args:
        model_dict: Dict as returned by ``load_model``, with keys ``"model"``,
            ``"device"`` and ``"context_len"``.
        context_frames: Context frames laid out as (N, H, W, C), oldest first,
            with pixel values on a 0-255 scale (array or array-like).

    Returns:
        The predicted frame as a uint8 array, transposed from the model's
        per-sample (C, H, W) output to (H, W, C). This assumes the model
        returns a (1, C, H, W) tensor as the first element of its output tuple.
    """
    model = model_dict["model"]
    device = model_dict["device"]
    context_len = model_dict["context_len"]
    frames = np.asarray(context_frames)

    def _forward(clip):
        # One model pass on a (N, H, W, C) clip; returns the prediction tensor.
        ctx, last = _prepare_input(clip, context_len)
        pred, _ = model(torch.from_numpy(ctx).to(device), torch.from_numpy(last).to(device))
        return pred

    with torch.no_grad():
        pred_orig = _forward(frames)
        # Mirror along the width axis (axis 2 of NHWC). copy() yields a
        # positive-stride array, which torch.from_numpy requires. The output
        # is flipped back on its last (width) axis before averaging.
        pred_flip = _forward(frames[:, :, ::-1, :].copy()).flip(-1)
        pred = (pred_orig + pred_flip) / 2.0

    pred_np = np.transpose(pred[0].cpu().numpy(), (1, 2, 0))
    # Round to the nearest level instead of truncating. astype alone would
    # bias every pixel downward by up to one intensity level.
    return np.rint(pred_np * 255.0).clip(0, 255).astype(np.uint8)