| import os |
| import math |
| import cv2 |
| import trimesh |
| import numpy as np |
|
|
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
|
|
| import nvdiffrast.torch as dr |
| from mesh import Mesh, safe_normalize |
|
|
def scale_img_nhwc(x, size, mag='bilinear', min='bilinear'):
    """Resize a batch of NHWC images to ``size`` = (height, width).

    Args:
        x: image batch indexed as (N, H, W, C).
        size: target (height, width).
        mag: ``interpolate`` mode used when enlarging; ``'bilinear'`` and
            ``'bicubic'`` are run with ``align_corners=True``.
        min: ``interpolate`` mode used when shrinking (applied whenever no
            axis grows and at least one axis shrinks).

    Returns:
        A contiguous (N, size[0], size[1], C) tensor.

    Raises:
        AssertionError: if one axis would be enlarged while the other shrinks.
    """
    src_h, src_w = x.shape[1], x.shape[2]
    dst_h, dst_w = size[0], size[1]
    assert (src_h >= dst_h and src_w >= dst_w) or (src_h < dst_h and src_w < dst_w), "Trying to magnify image in one dimension and minify in the other"
    y = x.permute(0, 3, 1, 2)
    # Use the minify filter when no axis grows and at least one shrinks,
    # e.g. (4, 4) -> (4, 2); requiring both axes to shrink would route such
    # resizes through the magnify path and ignore the requested ``min`` mode.
    if src_h >= dst_h and src_w >= dst_w and (src_h > dst_h or src_w > dst_w):
        y = torch.nn.functional.interpolate(y, size, mode=min)
    elif mag == 'bilinear' or mag == 'bicubic':
        y = torch.nn.functional.interpolate(y, size, mode=mag, align_corners=True)
    else:
        y = torch.nn.functional.interpolate(y, size, mode=mag)
    return y.permute(0, 2, 3, 1).contiguous()
|
|
def scale_img_hwc(x, size, mag='bilinear', min='bilinear'):
    """Resize a single HWC image to ``size`` (height, width).

    Adds a batch axis, delegates to ``scale_img_nhwc`` and drops the axis again.
    """
    batched = x.unsqueeze(0)
    resized = scale_img_nhwc(batched, size, mag, min)
    return resized.squeeze(0)
|
|
def scale_img_nhw(x, size, mag='bilinear', min='bilinear'):
    """Resize a batch of single-channel NHW images to ``size`` (height, width).

    Adds a trailing channel axis, delegates to ``scale_img_nhwc`` and removes it.
    """
    with_channel = x.unsqueeze(-1)
    resized = scale_img_nhwc(with_channel, size, mag, min)
    return resized.squeeze(-1)
|
|
def scale_img_hw(x, size, mag='bilinear', min='bilinear'):
    """Resize a single 2-D (HW) image to ``size`` (height, width).

    Wraps the image as a 1x H x W x 1 batch for ``scale_img_nhwc`` and unwraps
    the result.
    """
    as_nhwc = x.unsqueeze(0).unsqueeze(-1)
    resized = scale_img_nhwc(as_nhwc, size, mag, min)
    return resized[0, :, :, 0]
|
|
def trunc_rev_sigmoid(x, eps=1e-6):
    """Inverse sigmoid (logit) of ``x`` with the input clamped to ``[eps, 1 - eps]``.

    The clamp keeps the output finite for inputs at or beyond 0 and 1.
    """
    p = torch.clamp(x, min=eps, max=1 - eps)
    odds = p / (1 - p)
    return torch.log(odds)
|
|
def make_divisible(x, m=8):
    """Round ``x`` up to the nearest multiple of ``m`` and return it as an int."""
    multiples = math.ceil(x / m)
    return int(multiples * m)
|
|
class Renderer(nn.Module):
    """Differentiable textured-mesh renderer built on nvdiffrast.

    Holds the mesh loaded from ``opt.mesh`` and two learnable tensors:
    per-vertex position offsets (``v_offsets``, optimized only when
    ``opt.train_geo`` is set) and the albedo texture stored in logit space
    (``raw_albedo``; ``render`` applies a sigmoid to it).
    """

    def __init__(self, opt):
        """Load the mesh, choose a rasterizer context and create parameters.

        Args:
            opt: configuration object; this class reads ``mesh``,
                ``force_cuda_rast``, ``gui``, ``train_geo``, ``texture_lr`` and
                ``geom_lr`` from it.
        """
        super().__init__()

        self.opt = opt

        self.mesh = Mesh.load(self.opt.mesh, resize=False)

        # OpenGL rasterizer unless CUDA rasterization is forced, or a GUI is
        # running on a non-Windows OS.
        if not self.opt.force_cuda_rast and (not self.opt.gui or os.name == 'nt'):
            self.glctx = dr.RasterizeGLContext()
        else:
            self.glctx = dr.RasterizeCudaContext()

        # Learnable per-vertex displacement, zero-initialized.
        self.v_offsets = nn.Parameter(torch.zeros_like(self.mesh.v))
        # Albedo is optimized in logit space so that sigmoid() in render()
        # keeps the texture inside (0, 1).
        self.raw_albedo = nn.Parameter(trunc_rev_sigmoid(self.mesh.albedo))

    def get_params(self):
        """Return optimizer parameter groups.

        The texture is always included (lr ``opt.texture_lr``); the vertex
        offsets are added (lr ``opt.geom_lr``) only when ``opt.train_geo``.
        """
        params = [
            {'params': self.raw_albedo, 'lr': self.opt.texture_lr},
        ]

        if self.opt.train_geo:
            params.append({'params': self.v_offsets, 'lr': self.opt.geom_lr})

        return params

    @torch.no_grad()
    def export_mesh(self, save_path):
        """Bake the learned offsets and texture into ``self.mesh`` and write it.

        After baking, ``v_offsets`` is reset to zero so that later ``render``
        calls or a second export do not apply the same displacement twice.

        Args:
            save_path: destination path, forwarded to ``Mesh.write``.
        """
        self.mesh.v = (self.mesh.v + self.v_offsets).detach()
        # The displacement now lives in mesh.v. Zero it in place (same
        # Parameter object, so an optimizer holding it remains valid).
        self.v_offsets.zero_()
        self.mesh.albedo = torch.sigmoid(self.raw_albedo.detach())
        # NOTE(review): mesh.vn / mesh.fn are not recomputed after baking the
        # offsets; confirm whether Mesh.write needs refreshed normals.
        self.mesh.write(save_path)

    @staticmethod
    def _as_float_tensor(m, device):
        """Convert a matrix (numpy array, nested list or tensor) to a float32 tensor on ``device``."""
        if isinstance(m, torch.Tensor):
            return m.to(device=device, dtype=torch.float32)
        # astype() copies, so arrays with arbitrary strides are accepted.
        return torch.from_numpy(np.asarray(m).astype(np.float32)).to(device)

    @staticmethod
    def _vertex_normals(v, f):
        """Compute unit vertex normals from positions ``v`` and triangles ``f``.

        Every face adjacent to a vertex contributes its unit face normal; the
        sum is normalized. Vertices whose sum is (near) zero get (0, 0, 1).
        The result has the same shape as ``v`` and is indexed like ``v``.
        """
        i0, i1, i2 = f[:, 0].long(), f[:, 1].long(), f[:, 2].long()
        v0, v1, v2 = v[i0, :], v[i1, :], v[i2, :]

        # Explicit dim: without it torch.cross uses the first size-3 axis,
        # which is the face axis for a mesh with exactly three faces.
        face_normals = safe_normalize(torch.cross(v1 - v0, v2 - v0, dim=-1))

        vn = torch.zeros_like(v)
        for idx in (i0, i1, i2):
            vn.scatter_add_(0, idx[:, None].repeat(1, 3), face_normals)

        # Replace degenerate sums with a default direction, then normalize so
        # interpolation is not weighted by how many faces touch each vertex.
        vn = torch.where(
            torch.sum(vn * vn, -1, keepdim=True) > 1e-20,
            vn,
            torch.tensor([0.0, 0.0, 1.0], dtype=torch.float32, device=vn.device),
        )
        return safe_normalize(vn)

    def render(self, pose, proj, h0, w0, ssaa=1, bg_color=1, texture_filter='linear-mipmap-linear'):
        """Rasterize the textured mesh from one camera.

        Args:
            pose: 4x4 camera-to-world matrix (its inverse maps world-space
                vertices to camera space); numpy array or torch tensor.
            proj: 4x4 projection matrix; numpy array or torch tensor.
            h0, w0: output image height and width.
            ssaa: supersampling factor. When != 1 the mesh is rasterized at
                ``make_divisible(h0 * ssaa, 8)`` x ``make_divisible(w0 * ssaa, 8)``
                and every output is resized back to ``(h0, w0)``.
            bg_color: background value (scalar or tensor broadcastable to the
                image) composited where no triangle covers the pixel.
            texture_filter: ``filter_mode`` passed to ``dr.texture``.

        Returns:
            dict with ``image`` (H x W x C, clamped to [0, 1]), ``alpha``
            (H x W x 1 hard coverage mask), ``depth`` (H x W x 1, negated
            camera-space z), ``normal`` (H x W x 3 world-space normals mapped
            from [-1, 1] to [0, 1]) and ``viewcos`` (H x W x 1, z component of
            the normal rotated into camera space).
        """
        if ssaa != 1:
            h = make_divisible(h0 * ssaa, 8)
            w = make_divisible(w0 * ssaa, 8)
        else:
            h, w = h0, w0

        results = {}

        # Learned offsets are only applied while geometry is being optimized.
        if self.opt.train_geo:
            v = self.mesh.v + self.v_offsets
        else:
            v = self.mesh.v

        pose = self._as_float_tensor(pose, v.device)
        proj = self._as_float_tensor(proj, v.device)

        # World -> camera (homogeneous coordinates), then camera -> clip space.
        v_cam = torch.matmul(F.pad(v, pad=(0, 1), mode='constant', value=1.0), torch.inverse(pose).T).float().unsqueeze(0)
        v_clip = v_cam @ proj.T

        rast, rast_db = dr.rasterize(self.glctx, v_clip, self.mesh.f, (h, w))

        # Hard coverage mask from the last rasterizer channel.
        alpha = (rast[0, ..., 3:] > 0).float()
        depth, _ = dr.interpolate(-v_cam[..., [2]], rast, self.mesh.f)
        depth = depth.squeeze(0)

        texc, texc_db = dr.interpolate(self.mesh.vt.unsqueeze(0).contiguous(), rast, self.mesh.ft, rast_db=rast_db, diff_attrs='all')
        albedo = dr.texture(self.raw_albedo.unsqueeze(0), texc, uv_da=texc_db, filter_mode=texture_filter)
        albedo = torch.sigmoid(albedo)

        if self.opt.train_geo:
            # Normals of the displaced mesh are indexed like the vertices, so
            # they must be interpolated with the vertex faces ``f``, not ``fn``.
            vn = self._vertex_normals(v, self.mesh.f)
            normal_faces = self.mesh.f
        else:
            vn = self.mesh.vn
            normal_faces = self.mesh.fn

        normal, _ = dr.interpolate(vn.unsqueeze(0).contiguous(), rast, normal_faces)
        normal = safe_normalize(normal[0])

        # Row-vector form of R^T n: rotates world-space normals into camera
        # space, R being the rotation block of the camera-to-world pose.
        rot_normal = normal @ pose[:3, :3]
        viewcos = rot_normal[..., [2]]

        # Composite the background BEFORE antialiasing so silhouette pixels
        # blend with bg_color (keeping their visibility gradients) rather than
        # with whatever texel the uncovered pixels happened to sample.
        albedo = alpha * albedo + (1 - alpha) * bg_color
        albedo = dr.antialias(albedo.contiguous(), rast, v_clip, self.mesh.f).squeeze(0)

        if ssaa != 1:
            albedo = scale_img_hwc(albedo, (h0, w0))
            alpha = scale_img_hwc(alpha, (h0, w0))
            depth = scale_img_hwc(depth, (h0, w0))
            normal = scale_img_hwc(normal, (h0, w0))
            viewcos = scale_img_hwc(viewcos, (h0, w0))

        results['image'] = albedo.clamp(0, 1)
        results['alpha'] = alpha
        results['depth'] = depth
        results['normal'] = (normal + 1) / 2
        results['viewcos'] = viewcos

        return results