"""Inference for AR curriculum model + TTA.""" import json import numpy as np import torch import sys sys.path.insert(0, "/home/coder/code") from flow_warp_attn_model import FlowWarpAttnUNet def load_model(model_dir: str): with open(f"{model_dir}/config.json") as f: config = json.load(f) model = FlowWarpAttnUNet(in_channels=config["in_channels"], channels=config["channels"]) sd = torch.load(f"{model_dir}/model.pt", map_location="cpu", weights_only=True) sd = {k: v.float() for k, v in sd.items()} model.load_state_dict(sd) model.eval() device = torch.device("cuda" if torch.cuda.is_available() else "cpu") model = model.to(device) return {"model": model, "device": device, "context_len": config["context_len"]} def _prepare_input(context_frames, context_len): N = len(context_frames) if N >= context_len: frames = context_frames[-context_len:] else: pad = np.repeat(context_frames[:1], context_len - N, axis=0) frames = np.concatenate([pad, context_frames], axis=0) frames_f = frames.astype(np.float32) / 255.0 frames_f = np.transpose(frames_f, (0, 3, 1, 2)) context = frames_f.reshape(1, -1, 64, 64) last_frame = frames_f[-1:] return context, last_frame def predict_next_frame(model_dict, context_frames: np.ndarray) -> np.ndarray: model = model_dict["model"] device = model_dict["device"] context_len = model_dict["context_len"] ctx, last = _prepare_input(context_frames, context_len) with torch.no_grad(): ctx_t = torch.from_numpy(ctx).to(device) last_t = torch.from_numpy(last).to(device) pred1, _ = model(ctx_t, last_t) flipped_frames = context_frames[:, :, ::-1, :].copy() ctx_f, last_f = _prepare_input(flipped_frames, context_len) with torch.no_grad(): ctx_ft = torch.from_numpy(ctx_f).to(device) last_ft = torch.from_numpy(last_f).to(device) pred2, _ = model(ctx_ft, last_ft) pred2 = pred2.flip(-1) pred = (pred1 + pred2) / 2.0 pred_np = pred[0].cpu().numpy() pred_np = np.transpose(pred_np, (1, 2, 0)) return (pred_np * 255.0).clip(0, 255).astype(np.uint8)