"""Prediction interface for per-game Flow-Warp-Mask models v12 with motion encoding + TTA.""" import sys import os import numpy as np import torch sys.path.insert(0, "/home/coder/code") from flowmask_model import FlowWarpMaskUNet from flownet_model import differentiable_warp CONTEXT_LEN = 4 GAME_CONFIGS = { "pong": {"channels": [32, 64, 128], "file": "pong_model.pt"}, "sonic": {"channels": [40, 80, 160], "file": "sonic_model.pt"}, "pole_position": {"channels": [24, 48, 96], "file": "pole_model.pt"}, } def detect_game(context_frames): mean_val = context_frames.mean() if mean_val < 10: return "pong" elif mean_val < 80: return "sonic" else: return "pole_position" def load_model(model_dir: str): device = torch.device("cuda" if torch.cuda.is_available() else "cpu") models = {} for game, cfg in GAME_CONFIGS.items(): model = FlowWarpMaskUNet(in_channels=12, channels=cfg["channels"]) model_path = os.path.join(model_dir, cfg["file"]) state_dict = torch.load(model_path, map_location=device, weights_only=True) state_dict = {k: v.float() for k, v in state_dict.items()} model.load_state_dict(state_dict) model.to(device) model.eval() models[game] = model return {"models": models, "device": device} def _make_motion_input(frames): """Create motion encoding: last frame (3ch) + 3 pairwise diffs (9ch) = 12ch. Args: frames: (4, 3, H, W) tensor in [0,1] Returns: (12, H, W) tensor """ last = frames[-1] # (3, H, W) diff1 = frames[-1] - frames[-2] # most recent motion diff2 = frames[-2] - frames[-3] # previous motion diff3 = frames[-3] - frames[-4] # older motion return torch.cat([last, diff1, diff2, diff3], dim=0) # (12, H, W) def _prepare_context(context_frames): """Prepare 4-frame context from numpy frames.""" if len(context_frames) >= CONTEXT_LEN: frames = context_frames[-CONTEXT_LEN:] else: pad_count = CONTEXT_LEN - len(context_frames) padding = np.stack([context_frames[0]] * pad_count, axis=0) frames = np.concatenate([padding, context_frames], axis=0) frames_t = torch.from_numpy(frames.astype(np.float32) / 255.0) frames_t = frames_t.permute(0, 3, 1, 2) # (4, 3, 64, 64) return frames_t def _run_model(model, frames_t, device): """Run model with motion encoding input.""" last_frame = frames_t[-1].unsqueeze(0) # (1, 3, 64, 64) inp = _make_motion_input(frames_t).unsqueeze(0) # (1, 12, 64, 64) inp = inp.to(device) last_frame = last_frame.to(device) flow, mask, gen_frame = model(inp) warped = differentiable_warp(last_frame, flow) pred = mask * warped + (1 - mask) * gen_frame pred = torch.clamp(pred, 0, 1) return pred def predict_next_frame(model_dict, context_frames: np.ndarray) -> np.ndarray: models = model_dict["models"] device = model_dict["device"] game = detect_game(context_frames) model = models[game] frames_t = _prepare_context(context_frames) with torch.no_grad(): # Original prediction pred1 = _run_model(model, frames_t, device) # TTA: horizontally flipped prediction frames_flipped = frames_t.flip(-1) pred2_flipped = _run_model(model, frames_flipped, device) pred2 = pred2_flipped.flip(-1) # Average pred = (pred1 + pred2) / 2.0 pred = pred[0].cpu().permute(1, 2, 0).numpy() pred = (pred * 255).clip(0, 255).astype(np.uint8) # Post-processing for Pong: clamp dark pixels to pure black if game == "pong": pred[pred < 5] = 0 return pred