| import os |
| import sys |
| import numpy as np |
| from PIL import Image |
| import torch |
| import torchvision.transforms as transforms |
| from argparse import Namespace |
| from e4e.models.psp import pSp |
| from util import * |
|
|
|
|
|
|
@torch.no_grad()
def projection(img, name, device='cuda', model_path='e4e_ffhq_encode.pt'):
    """Encode an image into a latent with an e4e (pSp) encoder and save it.

    Args:
        img: Image fed through the torchvision pipeline below (resize to 256,
            center-crop 256x256, to tensor, normalize to [-1, 1]). Presumably
            a PIL image of an aligned face -- TODO confirm with callers.
            Non-RGB PIL images are converted to RGB first.
        name: Destination path; ``torch.save`` writes ``{'latent': latent}``
            there.
        device: Device the encoder and the input batch are moved to.
        model_path: e4e checkpoint to load. Defaults to the previously
            hard-coded ``'e4e_ffhq_encode.pt'``.

    Returns:
        ``w_plus[0]``: the latent returned by the encoder for the single input
        image (batch dimension removed), left on ``device``.
    """
    # NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
    ckpt = torch.load(model_path, map_location='cpu')
    opts = ckpt['opts']
    opts['checkpoint_path'] = model_path
    opts = Namespace(**opts)
    # pSp receives checkpoint_path via opts; presumably it loads the weights
    # from there itself -- confirm in e4e.models.psp.
    net = pSp(opts, device).eval().to(device)
    # Only the options were needed from this copy of the checkpoint; drop it so
    # its state dict is not held in memory during the forward pass.
    del ckpt

    transform = transforms.Compose(
        [
            transforms.Resize(256),
            transforms.CenterCrop(256),
            transforms.ToTensor(),
            transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
        ]
    )

    # Normalize uses 3-channel statistics: grayscale ('L'), RGBA or palette
    # images would yield 1 or 4 channels and fail there, so coerce to RGB.
    if isinstance(img, Image.Image) and img.mode != 'RGB':
        img = img.convert('RGB')

    batch = transform(img).unsqueeze(0).to(device)
    _, w_plus = net(batch, randomize_noise=False, return_latents=True)
    latent = w_plus[0]
    torch.save({'latent': latent}, name)
    return latent
|
|