| import sys |
| import os |
|
|
| |
# Make the project root importable no matter where the script is launched
# from: the root is taken to be two directory levels above this file.
_SCRIPT_DIR = os.path.dirname(os.path.abspath(__file__))
_PROJECT_ROOT = os.path.abspath(
    os.path.join(_SCRIPT_DIR, os.pardir, os.pardir)
)
if _PROJECT_ROOT not in sys.path:
    # Prepend so project modules win over same-named installed packages.
    sys.path.insert(0, _PROJECT_ROOT)
|
|
| import numpy as np |
| import torch |
| from PIL import Image |
| import torchvision.transforms as T |
| from omegaconf import OmegaConf |
| import fire |
|
|
def init_fn(config_path):
    """Build the image transform and load the motion encoder.

    The config at ``config_path`` is read with OmegaConf. ``config.model`` is
    resolved through the project's ``utils.instantiate`` (called with
    ``instantiate_module=False``) and the result is called with
    ``config=config`` to build the model. Weights are read from
    ``config.resume_ckpt`` onto the CPU and applied non-strictly, and the
    model is switched to eval mode.

    Args:
        config_path: Path to the configuration file; it must provide
            ``model`` and ``resume_ckpt`` entries.

    Returns:
        A dict with two entries: ``"transform"`` (resize to 512x512, convert
        to tensor, normalize with mean 0.5 / std 0.5) and
        ``"motion_encoder"`` (the ``motion_encoder`` attribute of the loaded
        model).
    """
    # Imported lazily: `utils` is only resolvable once the caller has put the
    # right directory on sys.path.
    from utils import instantiate

    preprocess = T.Compose([
        T.Resize((512, 512)),
        T.ToTensor(),
        T.Normalize([0.5], [0.5]),
    ])

    cfg = OmegaConf.load(config_path)
    model_factory = instantiate(cfg.model, instantiate_module=False)
    net = model_factory(config=cfg)

    weights = torch.load(cfg.resume_ckpt, map_location="cpu")["state_dict"]
    # strict=False: keys that do not match between checkpoint and model are
    # tolerated rather than raising.
    net.load_state_dict(weights, strict=False)
    net.eval()

    return {"transform": preprocess, "motion_encoder": net.motion_encoder}
|
|
def extract_motion_latent(
        mask_image_path='./test_case/test_img_masked.png',
        config_path='./configs/head_animator_best_0506.yaml',
        save_npz_path='./test_case/test_img_resize.npz',
        version="0506"):
    """Encode a masked image into a motion latent and store it in an npz file.

    The image is loaded as RGB, passed through the transform returned by
    ``init_fn`` and fed to the motion encoder with gradients disabled. The
    first element of the encoder output is saved under ``'motion_latent'``.
    If the target archive already exists, its entries are kept and only the
    four keys written here (``video_id``, ``mask_img_path``,
    ``ref_img_path``, ``motion_latent``) are added or overwritten.

    Args:
        mask_image_path: Path of the input image.
        config_path: Path of the model config. Every occurrence of ``"0506"``
            in it is replaced by ``version``.
        save_npz_path: Output archive. As with ``np.savez``, ``.npz`` is
            appended when the path does not already end with it.
        version: Model version tag. It selects ``./utils/model_{version}``
            (relative to the current working directory) as an extra import
            location and is substituted into ``config_path``.

    Returns:
        None. The result is written to disk.
    """
    # A CLI parser may deliver a purely numeric tag without a leading zero as
    # an int (NOTE(review): python-fire appears to do this - confirm);
    # str.replace below requires a string.
    version = str(version)
    sys.path.insert(0, f'./utils/model_{version}')
    config_path = config_path.replace("0506", version)

    context = init_fn(config_path)
    transform = context["transform"]
    motion_encoder = context["motion_encoder"]

    # Close the file handle as soon as the pixels have been converted.
    with Image.open(mask_image_path) as source_image:
        img = source_image.convert("RGB")
    img_tensor = transform(img).unsqueeze(0)
    with torch.no_grad():
        latent = motion_encoder(img_tensor)[0]
    # detach()/cpu() are no-ops for a CPU tensor without grad, but keep the
    # conversion valid if the encoder ever runs on another device.
    latent_np = latent.detach().cpu().numpy()

    # np.savez appends ".npz" when it is missing, so resolve the real file
    # name up front; otherwise the existence check below would look at a
    # different path than the one that gets written.
    save_npz_path = os.fspath(save_npz_path)
    if save_npz_path.endswith('.npz'):
        npz_file = save_npz_path
    else:
        npz_file = save_npz_path + '.npz'
    stem = npz_file[:-len('.npz')]

    if os.path.exists(npz_file):
        # SECURITY: allow_pickle=True can execute arbitrary code when the
        # archive comes from an untrusted source; only merge into files this
        # pipeline produced itself.
        with np.load(npz_file, allow_pickle=True) as existing_data:
            data_dict = dict(existing_data)
    else:
        data_dict = {}

    data_dict.update({
        'video_id': os.path.basename(stem),
        'mask_img_path': mask_image_path,
        # Swap only the extension; a plain str.replace('npz', 'png') would
        # also rewrite "npz" inside directory or file names.
        'ref_img_path': stem + '.png',
        'motion_latent': latent_np,
    })

    np.savez(npz_file, **data_dict)
| |
| |
| |
| |
| |
| |
| |
# Script entry point: python-fire exposes extract_motion_latent's keyword
# arguments as command-line flags.
if __name__ == "__main__":
    fire.Fire(extract_motion_latent)
|
|