# world-model / predict.py
# Uploaded via huggingface_hub (commit 99c8044, verified; author: ojaffe)
"""Prediction interface for per-game Flow-Warp-Mask models v12 with motion encoding + TTA."""
import sys
import os
import numpy as np
import torch
sys.path.insert(0, "/home/coder/code")
from flowmask_model import FlowWarpMaskUNet
from flownet_model import differentiable_warp
# Number of past frames consumed as temporal context by the model.
CONTEXT_LEN = 4
# Per-game UNet channel widths and checkpoint filenames.
# Keys must match the strings returned by detect_game().
GAME_CONFIGS = {
    "pong": {"channels": [32, 64, 128], "file": "pong_model.pt"},
    "sonic": {"channels": [40, 80, 160], "file": "sonic_model.pt"},
    "pole_position": {"channels": [24, 48, 96], "file": "pole_model.pt"},
}
def detect_game(context_frames):
    """Identify which game the context belongs to via mean pixel brightness.

    Thresholds: < 10 -> "pong", < 80 -> "sonic", otherwise "pole_position".
    """
    brightness = context_frames.mean()
    if brightness < 10:
        return "pong"
    if brightness < 80:
        return "sonic"
    return "pole_position"
def load_model(model_dir: str):
    """Load every per-game checkpoint from *model_dir* onto the best device.

    Returns a dict with "models" (game name -> eval-mode FlowWarpMaskUNet)
    and "device" (the torch.device the models live on).
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    models = {}
    for game, cfg in GAME_CONFIGS.items():
        net = FlowWarpMaskUNet(in_channels=12, channels=cfg["channels"])
        checkpoint_path = os.path.join(model_dir, cfg["file"])
        # weights_only=True restricts unpickling to tensor data (safe load).
        weights = torch.load(checkpoint_path, map_location=device, weights_only=True)
        # Cast every tensor to float32 in case the checkpoint was saved in half precision.
        net.load_state_dict({name: tensor.float() for name, tensor in weights.items()})
        net.to(device)
        net.eval()
        models[game] = net
    return {"models": models, "device": device}
def _make_motion_input(frames):
"""Create motion encoding: last frame (3ch) + 3 pairwise diffs (9ch) = 12ch.
Args:
frames: (4, 3, H, W) tensor in [0,1]
Returns:
(12, H, W) tensor
"""
last = frames[-1] # (3, H, W)
diff1 = frames[-1] - frames[-2] # most recent motion
diff2 = frames[-2] - frames[-3] # previous motion
diff3 = frames[-3] - frames[-4] # older motion
return torch.cat([last, diff1, diff2, diff3], dim=0) # (12, H, W)
def _prepare_context(context_frames):
    """Turn the trailing numpy frames into a normalized (4, 3, H, W) tensor.

    Histories shorter than CONTEXT_LEN are left-padded by repeating the
    oldest frame; longer histories keep only the last CONTEXT_LEN frames.
    """
    history = len(context_frames)
    if history < CONTEXT_LEN:
        pad = np.repeat(context_frames[:1], CONTEXT_LEN - history, axis=0)
        frames = np.concatenate([pad, context_frames], axis=0)
    else:
        frames = context_frames[-CONTEXT_LEN:]
    # Scale uint8 pixels to [0, 1] and move channels first: (T, H, W, C) -> (T, C, H, W).
    tensor = torch.from_numpy(frames.astype(np.float32) / 255.0)
    return tensor.permute(0, 3, 1, 2)
def _run_model(model, frames_t, device):
    """Predict one frame: warp the last frame by predicted flow, blend with generated.

    The model emits (flow, mask, generated); the output is
    mask * warp(last, flow) + (1 - mask) * generated, clamped to [0, 1].
    """
    last = frames_t[-1].unsqueeze(0).to(device)                     # (1, 3, 64, 64)
    motion = _make_motion_input(frames_t).unsqueeze(0).to(device)   # (1, 12, 64, 64)
    flow, mask, generated = model(motion)
    warped = differentiable_warp(last, flow)
    # Mask selects per-pixel between the flow-warped frame and the generated one.
    blended = mask * warped + (1 - mask) * generated
    return torch.clamp(blended, 0, 1)
def predict_next_frame(model_dict, context_frames: np.ndarray) -> np.ndarray:
    """Predict the next frame as uint8 (H, W, 3), averaging a TTA pair.

    Runs the game-specific model on the context and on its horizontal
    mirror, un-flips the second prediction, and averages the two.
    """
    device = model_dict["device"]
    game = detect_game(context_frames)
    model = model_dict["models"][game]
    context = _prepare_context(context_frames)
    with torch.no_grad():
        plain = _run_model(model, context, device)
        # TTA: predict on the mirrored context, then mirror back.
        mirrored = _run_model(model, context.flip(-1), device).flip(-1)
        averaged = (plain + mirrored) / 2.0
    frame = averaged[0].cpu().permute(1, 2, 0).numpy()
    frame = (frame * 255).clip(0, 255).astype(np.uint8)
    if game == "pong":
        # Pong backgrounds are pure black; squash near-black noise to 0.
        frame[frame < 5] = 0
    return frame