"""Prediction interface for per-game Flow-Warp-Mask models v12 with motion encoding + TTA."""
import sys
import os
import numpy as np
import torch
sys.path.insert(0, "/home/coder/code")
from flowmask_model import FlowWarpMaskUNet
from flownet_model import differentiable_warp
# Number of past frames each model conditions on.
CONTEXT_LEN = 4
# Per-game UNet channel widths and checkpoint filenames.
# Keys match the labels returned by detect_game().
GAME_CONFIGS = {
"pong": {"channels": [32, 64, 128], "file": "pong_model.pt"},
"sonic": {"channels": [40, 80, 160], "file": "sonic_model.pt"},
"pole_position": {"channels": [24, 48, 96], "file": "pole_model.pt"},
}
def detect_game(context_frames):
    """Classify which game a clip comes from by its overall mean brightness.

    Args:
        context_frames: numpy array of context frames (any shape with pixel values).
    Returns:
        One of "pong", "sonic", or "pole_position".
    """
    brightness = context_frames.mean()
    # Pong is mostly black, Sonic is mid-brightness, Pole Position is brightest.
    if brightness < 10:
        return "pong"
    return "sonic" if brightness < 80 else "pole_position"
def load_model(model_dir: str):
    """Load one FlowWarpMaskUNet per game from checkpoints in *model_dir*.

    Args:
        model_dir: directory containing the per-game checkpoint files
            named in GAME_CONFIGS.
    Returns:
        Dict with "models" (game name -> eval-mode model on the chosen
        device) and "device" (the torch.device used).
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    loaded = {}
    for game, cfg in GAME_CONFIGS.items():
        net = FlowWarpMaskUNet(in_channels=12, channels=cfg["channels"])
        checkpoint = torch.load(
            os.path.join(model_dir, cfg["file"]),
            map_location=device,
            weights_only=True,
        )
        # Cast every tensor to float32 before loading (checkpoints may be
        # stored in a lower precision — TODO confirm).
        net.load_state_dict({name: t.float() for name, t in checkpoint.items()})
        net.to(device)
        net.eval()
        loaded[game] = net
    return {"models": loaded, "device": device}
def _make_motion_input(frames):
"""Create motion encoding: last frame (3ch) + 3 pairwise diffs (9ch) = 12ch.
Args:
frames: (4, 3, H, W) tensor in [0,1]
Returns:
(12, H, W) tensor
"""
last = frames[-1] # (3, H, W)
diff1 = frames[-1] - frames[-2] # most recent motion
diff2 = frames[-2] - frames[-3] # previous motion
diff3 = frames[-3] - frames[-4] # older motion
return torch.cat([last, diff1, diff2, diff3], dim=0) # (12, H, W)
def _prepare_context(context_frames):
    """Convert numpy context frames into a (CONTEXT_LEN, 3, H, W) float tensor.

    Keeps the last CONTEXT_LEN frames. If fewer are available, the earliest
    frame is repeated at the front to pad up to CONTEXT_LEN.

    Args:
        context_frames: (T, H, W, 3) uint8-style numpy array.
    Returns:
        (CONTEXT_LEN, 3, H, W) float tensor scaled to [0, 1].
    """
    available = len(context_frames)
    if available < CONTEXT_LEN:
        missing = CONTEXT_LEN - available
        front = np.stack([context_frames[0]] * missing, axis=0)
        window = np.concatenate([front, context_frames], axis=0)
    else:
        window = context_frames[-CONTEXT_LEN:]
    scaled = torch.from_numpy(window.astype(np.float32) / 255.0)
    # (T, H, W, C) -> (T, C, H, W)
    return scaled.permute(0, 3, 1, 2)
def _run_model(model, frames_t, device):
    """Predict one frame by compositing a flow-warped frame with a generated one.

    Args:
        model: FlowWarpMaskUNet returning (flow, mask, generated_frame).
        frames_t: (4, 3, H, W) context tensor in [0, 1].
        device: torch device to run on.
    Returns:
        (1, 3, H, W) predicted frame clamped to [0, 1].
    """
    motion = _make_motion_input(frames_t).unsqueeze(0).to(device)  # (1, 12, H, W)
    previous = frames_t[-1].unsqueeze(0).to(device)  # (1, 3, H, W)
    flow, mask, generated = model(motion)
    # Blend the flow-warped previous frame with the freshly generated frame,
    # using the predicted mask as the per-pixel mixing weight.
    warped_prev = differentiable_warp(previous, flow)
    blended = mask * warped_prev + (1 - mask) * generated
    return torch.clamp(blended, 0, 1)
def predict_next_frame(model_dict, context_frames: np.ndarray) -> np.ndarray:
    """Predict the next frame for a clip, with horizontal-flip TTA.

    Args:
        model_dict: output of load_model() — {"models": ..., "device": ...}.
        context_frames: (T, H, W, 3) numpy array of past frames.
    Returns:
        (H, W, 3) uint8 numpy array: the predicted next frame.
    """
    device = model_dict["device"]
    game = detect_game(context_frames)
    net = model_dict["models"][game]
    frames_t = _prepare_context(context_frames)
    with torch.no_grad():
        straight = _run_model(net, frames_t, device)
        # Test-time augmentation: predict on the mirrored clip, then
        # mirror the result back and average with the straight pass.
        mirrored = _run_model(net, frames_t.flip(-1), device)
        averaged = (straight + mirrored.flip(-1)) / 2.0
    out = averaged[0].cpu().permute(1, 2, 0).numpy()
    out = (out * 255).clip(0, 255).astype(np.uint8)
    if game == "pong":
        # Pong backgrounds are pure black; snap near-black noise to 0.
        out[out < 5] = 0
    return out