File size: 3,711 Bytes
99c8044
7cf6bed
 
 
 
 
 
99c8044
87bfad6
7cf6bed
87bfad6
99c8044
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f5366ce
 
7cf6bed
87bfad6
99c8044
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
87bfad6
 
7cf6bed
87bfad6
 
 
2fd1c51
87bfad6
99c8044
87bfad6
2fd1c51
339daa1
87bfad6
99c8044
 
 
d4efc46
87bfad6
 
07239aa
99c8044
87bfad6
 
 
 
f2eecc0
2fd1c51
87bfad6
99c8044
87bfad6
f2eecc0
99c8044
 
 
 
f2eecc0
87bfad6
 
 
f2eecc0
87bfad6
5a6434e
87bfad6
5a6434e
2e7cf8e
87bfad6
 
2e7cf8e
87bfad6
 
99c8044
 
 
 
 
87bfad6
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
"""Prediction interface for per-game Flow-Warp-Mask models v12 with motion encoding + TTA."""
import sys
import os
import numpy as np
import torch

sys.path.insert(0, "/home/coder/code")
from flowmask_model import FlowWarpMaskUNet
from flownet_model import differentiable_warp

# Number of most-recent frames fed to the model (see _prepare_context).
CONTEXT_LEN = 4

# Per-game model settings, in the order load_model builds them:
#   channels -> passed as `channels=` to FlowWarpMaskUNet
#   file     -> checkpoint filename looked up inside load_model's model_dir
GAME_CONFIGS = dict(
    pong=dict(channels=[32, 64, 128], file="pong_model.pt"),
    sonic=dict(channels=[40, 80, 160], file="sonic_model.pt"),
    pole_position=dict(channels=[24, 48, 96], file="pole_model.pt"),
)


def detect_game(context_frames):
    """Classify which game produced ``context_frames`` from its mean pixel value.

    mean < 10        -> "pong"
    10 <= mean < 80  -> "sonic"
    otherwise        -> "pole_position" (this also catches a NaN mean)
    """
    brightness = context_frames.mean()
    for upper_bound, game in ((10, "pong"), (80, "sonic")):
        if brightness < upper_bound:
            return game
    return "pole_position"


def load_model(model_dir: str):
    """Build and load one FlowWarpMaskUNet per entry of GAME_CONFIGS.

    Args:
        model_dir: directory holding the checkpoint files named by each
            GAME_CONFIGS entry's "file" field.

    Returns:
        dict with
          "models": game name -> model (weights loaded, on device, eval mode)
          "device": the torch.device used (CUDA when available, else CPU)
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    def _build(cfg):
        # Every game uses the 12-channel motion encoding input.
        net = FlowWarpMaskUNet(in_channels=12, channels=cfg["channels"])
        checkpoint_path = os.path.join(model_dir, cfg["file"])
        weights = torch.load(checkpoint_path, map_location=device, weights_only=True)
        # Cast every stored tensor to float32; presumably checkpoints may be
        # saved in reduced precision — confirm against the training script.
        net.load_state_dict({name: tensor.float() for name, tensor in weights.items()})
        net.to(device)
        net.eval()
        return net

    models = {game: _build(cfg) for game, cfg in GAME_CONFIGS.items()}
    return {"models": models, "device": device}


def _make_motion_input(frames):
    """Stack the newest frame with three consecutive temporal differences.

    Only the last four frames along dim 0 are read.

    Args:
        frames: (4, 3, H, W) tensor in [0,1]
    Returns:
        (12, H, W) tensor laid out as
        [f[-1], f[-1] - f[-2], f[-2] - f[-3], f[-3] - f[-4]]
    """
    # Newest-first differences: offset -1 is the most recent motion.
    motion = [frames[step] - frames[step - 1] for step in (-1, -2, -3)]
    return torch.cat([frames[-1], *motion], dim=0)


def _prepare_context(context_frames, context_len=None):
    """Turn raw frames into a fixed-length, channel-first float tensor.

    Keeps the last ``context_len`` frames. When fewer are available, the
    oldest frame is repeated at the front so the newest frame stays last.

    Args:
        context_frames: array (or sequence) of frames indexed (T, H, W, C);
            presumably uint8 in [0, 255] — values are divided by 255.
        context_len: window length; defaults to the module-level CONTEXT_LEN.

    Returns:
        float32 tensor of shape (context_len, C, H, W).

    Raises:
        ValueError: if ``context_frames`` is empty or ``context_len`` < 1.
    """
    if context_len is None:
        context_len = CONTEXT_LEN
    if context_len < 1:
        raise ValueError(f"context_len must be >= 1, got {context_len}")

    # Coerce lists of frames too: a list slice has no .astype.
    frames = np.asarray(context_frames)
    if len(frames) == 0:
        raise ValueError("context_frames must contain at least one frame")

    if len(frames) >= context_len:
        frames = frames[-context_len:]
    else:
        # Left-pad by repeating the oldest frame.
        padding = np.repeat(frames[:1], context_len - len(frames), axis=0)
        frames = np.concatenate([padding, frames], axis=0)

    frames_t = torch.from_numpy(frames.astype(np.float32) / 255.0)
    return frames_t.permute(0, 3, 1, 2)  # (T, H, W, C) -> (T, C, H, W)


def _run_model(model, frames_t, device):
    """Predict one frame by blending a flow-warped last frame with a generated one.

    Args:
        model: callable returning (flow, mask, gen_frame) for a motion-encoded
            batch of size 1.
        frames_t: context tensor; only its last four frames are used, and
            frames_t[-1] is the frame that gets warped.
        device: torch device the inputs are moved to before the forward pass.

    Returns:
        Tensor clamped to [0, 1]: mask * warp(last_frame, flow)
        + (1 - mask) * gen_frame, with a leading batch dim of 1.
    """
    model_input = _make_motion_input(frames_t).unsqueeze(0).to(device)
    anchor = frames_t[-1].unsqueeze(0).to(device)

    flow, mask, generated = model(model_input)
    warped = differentiable_warp(anchor, flow)
    # mask selects the warped pixels; (1 - mask) selects the generated ones.
    return (mask * warped + (1 - mask) * generated).clamp(0, 1)


def predict_next_frame(model_dict, context_frames: np.ndarray) -> np.ndarray:
    """Predict the next frame using the per-game model picked by detect_game.

    Applies horizontal-flip test-time augmentation. The mirrored prediction is
    flipped back and averaged with the plain one.

    Args:
        model_dict: dict produced by load_model ("models" and "device").
        context_frames: recent frames, presumably (T, H, W, 3) uint8 —
            TODO confirm against the caller.

    Returns:
        (H, W, C) uint8 image. For pong, values below 5 are forced to 0.
    """
    models = model_dict["models"]
    device = model_dict["device"]

    game = detect_game(context_frames)
    model = models[game]
    frames_t = _prepare_context(context_frames)

    with torch.no_grad():
        pred_plain = _run_model(model, frames_t, device)
        # TTA: run on the mirrored context, then mirror the result back so
        # both predictions share the same pixel layout before averaging.
        pred_mirror = _run_model(model, frames_t.flip(-1), device).flip(-1)
        pred = (pred_plain + pred_mirror) / 2.0

    pred = pred[0].cpu().permute(1, 2, 0).numpy()
    # Round to nearest instead of truncating. astype(uint8) floors, which
    # biases every pixel down by ~0.5 on average and turns a value a hair
    # below an integer (float32 rounding) into the integer below.
    pred = np.rint(pred * 255.0).clip(0, 255).astype(np.uint8)

    # Pong post-processing: snap near-black pixels to pure black.
    if game == "pong":
        pred[pred < 5] = 0

    return pred