| import argparse |
| import sys |
| sys.path.append("..") |
| sys.path.append(".") |
| import numpy as np |
|
|
| import mcubes |
| import os |
| import torch |
|
|
| import trimesh |
|
|
| from datasets.SingleView_dataset import Object_PartialPoints_MultiImg |
| from datasets.transforms import Scale_Shift_Rotate |
| from models import get_model |
| from pathlib import Path |
| import open3d as o3d |
| from configs.config_utils import CONFIG |
| import cv2 |
| from util.misc import MetricLogger |
| import scipy |
| from pyTorchChamferDistance.chamfer_distance import ChamferDistance |
| from util.projection_utils import draw_proj_image |
| from util import misc |
| import time |
| dist_chamfer=ChamferDistance() |
|
|
|
|
| def pc_metrics(p1, p2, space_ext=2, fscore_param=0.01, scale=.5): |
| """ p2: reference ponits |
| (B, N, 3) |
| """ |
| p1, p2, space_ext = p1 * scale, p2 * scale, space_ext * scale |
| f_thresh = space_ext * fscore_param |
|
|
| |
| d1, d2, _, _ = dist_chamfer(p1, p2) |
| |
| d1sqrt, d2sqrt = (d1 ** .5), (d2 ** .5) |
| chamfer_L1 = d1sqrt.mean(axis=-1) + d2sqrt.mean(axis=-1) |
| chamfer_L2 = d1.mean(axis=-1) + d2.mean(axis=-1) |
| precision = (d1sqrt < f_thresh).sum(axis=-1).float() / p1.shape[1] |
| recall = (d2sqrt < f_thresh).sum(axis=-1).float() / p2.shape[1] |
| |
| fscore = 2 * torch.div(recall * precision, recall + precision) |
| fscore[fscore == float("inf")] = 0 |
| return chamfer_L1,chamfer_L2,fscore |
|
|
| if __name__ == "__main__": |
|
|
| parser = argparse.ArgumentParser('this script can be used to compute iou fscore chamfer distance before icp align', add_help=False) |
| parser.add_argument('--configs',type=str,required=True) |
| parser.add_argument('--output_folder', type=str, default="../output_result/Triplane_diff_parcond_0926") |
| parser.add_argument('--dm-pth',type=str) |
| parser.add_argument('--ae-pth',type=str) |
| parser.add_argument('--data-pth', type=str,default="../") |
| parser.add_argument('--save_mesh',action="store_true",default=False) |
| parser.add_argument('--save_image',action="store_true",default=False) |
| parser.add_argument('--save_par_points', action="store_true", default=False) |
| parser.add_argument('--save_proj_img',action="store_true",default=False) |
| parser.add_argument('--save_surface',action="store_true",default=False) |
| parser.add_argument('--reso',default=128,type=int) |
| parser.add_argument('--category',nargs="+",type=str) |
| parser.add_argument('--eval_cd',action="store_true",default=False) |
| parser.add_argument('--use_augmentation',action="store_true",default=False) |
|
|
| parser.add_argument('--world_size', default=1, type=int, |
| help='number of distributed processes') |
| parser.add_argument('--local_rank', default=-1, type=int) |
| parser.add_argument('--dist_on_itp', action='store_true') |
| parser.add_argument('--dist_url', default='env://', |
| help='url used to set up distributed training') |
| parser.add_argument('--device', default='cuda', |
| help='device to use for training / testing') |
| args = parser.parse_args() |
| misc.init_distributed_mode(args) |
| config_path=args.configs |
| config=CONFIG(config_path) |
| dataset_config=config.config['dataset'] |
| dataset_config['data_path']=args.data_pth |
| if "arkit" in args.category[0]: |
| split_filename=dataset_config['keyword']+'_val_par_img.json' |
| else: |
| split_filename='val_par_img.json' |
|
|
| transform = None |
| if args.use_augmentation: |
| transform=Scale_Shift_Rotate(jitter_partial=False,jitter=False,use_scale=False,angle=(-10,10),shift=(-0.1,0.1)) |
| dataset_val = Object_PartialPoints_MultiImg(dataset_config['data_path'], split_filename=split_filename,categories=args.category,split="val", |
| transform=transform, sampling=False, |
| num_samples=1024, return_surface=True,ret_sample=True, |
| surface_sampling=True, par_pc_size=dataset_config['par_pc_size'],surface_size=100000, |
| load_proj_mat=True,load_image=True,load_org_img=True,load_triplane=None,par_point_aug=None,replica=1) |
| batch_size=1 |
|
|
| num_tasks = misc.get_world_size() |
| global_rank = misc.get_rank() |
| val_sampler = torch.utils.data.DistributedSampler( |
| dataset_val, num_replicas=num_tasks, rank=global_rank, |
| shuffle=False) |
| dataloader_val=torch.utils.data.DataLoader( |
| dataset_val, |
| sampler=val_sampler, |
| batch_size=batch_size, |
| num_workers=10, |
| shuffle=False, |
| ) |
| output_folder=args.output_folder |
|
|
| device = torch.device('cuda') |
|
|
| ae_config=config.config['model']['ae'] |
| dm_config=config.config['model']['dm'] |
| ae_model=get_model(ae_config).to(device) |
| if args.category[0] == "all": |
| dm_config["use_cat_embedding"]=True |
| else: |
| dm_config["use_cat_embedding"] = False |
| dm_model=get_model(dm_config).to(device) |
| ae_model.eval() |
| dm_model.eval() |
| ae_model.load_state_dict(torch.load(args.ae_pth)['model']) |
| dm_model.load_state_dict(torch.load(args.dm_pth)['model']) |
|
|
| density = args.reso |
| gap = 2.2 / density |
| x = np.linspace(-1.1, 1.1, int(density + 1)) |
| y = np.linspace(-1.1, 1.1, int(density + 1)) |
| z = np.linspace(-1.1, 1.1, int(density + 1)) |
| xv, yv, zv = np.meshgrid(x, y, z,indexing='ij') |
| grid = torch.from_numpy(np.stack([xv, yv, zv]).astype(np.float32)).view(3, -1).transpose(0, 1)[None].to(device,non_blocking=True) |
|
|
| metric_logger=MetricLogger(delimiter=" ") |
| header = 'Test:' |
|
|
| with torch.no_grad(): |
| for data_batch in metric_logger.log_every(dataloader_val,10, header): |
| |
| |
| partial_name = data_batch['partial_name'] |
| class_name = data_batch['class_name'] |
| model_ids=data_batch['model_id'] |
| surface=data_batch['surface'] |
| proj_matrices=data_batch['proj_mat'] |
| sample_points=data_batch["points"].cuda().float() |
| labels=data_batch["labels"].cuda().float() |
| sample_input=dm_model.prepare_sample_data(data_batch) |
| |
| sampled_array = dm_model.sample(sample_input,num_steps=36).float() |
| |
| |
| |
| sampled_array = torch.nn.functional.interpolate(sampled_array, scale_factor=2, mode="bilinear") |
| for j in range(sampled_array.shape[0]): |
| if args.save_mesh | args.save_par_points | args.save_image: |
| object_folder = os.path.join(output_folder, class_name[j], model_ids[j]) |
| Path(object_folder).mkdir(parents=True, exist_ok=True) |
| '''calculate iou''' |
| sample_point=sample_points[j:j+1] |
| sample_output=ae_model.decode(sampled_array[j:j + 1],sample_point) |
| sample_pred=torch.zeros_like(sample_output) |
| sample_pred[sample_output>=0.0]=1 |
| label=labels[j:j+1] |
| intersection = (sample_pred * label).sum(dim=1) |
| union = (sample_pred + label).gt(0).sum(dim=1) |
| iou = intersection * 1.0 / union + 1e-5 |
| iou = iou.mean() |
| metric_logger.update(iou=iou.item()) |
|
|
| if args.use_augmentation: |
| tran_mat=data_batch["tran_mat"][j].numpy() |
| mat_save_path='{}/tran_mat.npy'.format(object_folder) |
| np.save(mat_save_path,tran_mat) |
|
|
| if args.eval_cd: |
| grid_list=torch.split(grid,128**3,dim=1) |
| output_list=[] |
| |
| for sub_grid in grid_list: |
| output_list.append(ae_model.decode(sampled_array[j:j + 1],sub_grid)) |
| output=torch.cat(output_list,dim=1) |
| |
| |
| |
| logits = output[j].detach() |
|
|
| volume = logits.view(density + 1, density + 1, density + 1).cpu().numpy() |
| verts, faces = mcubes.marching_cubes(volume, 0) |
|
|
| verts *= gap |
| verts -= 1.1 |
| |
|
|
|
|
| m = trimesh.Trimesh(verts, faces) |
| '''calculate fscore and chamfer distance''' |
| result_surface,_=trimesh.sample.sample_surface(m,100000) |
| gt_surface=surface[j] |
| assert gt_surface.shape[0]==result_surface.shape[0] |
|
|
| result_surface_gpu = torch.from_numpy(result_surface).float().cuda().unsqueeze(0) |
| gt_surface_gpu = gt_surface.float().cuda().unsqueeze(0) |
| _,chamfer_L2,fscore=pc_metrics(result_surface_gpu,gt_surface_gpu) |
| metric_logger.update(chamferl2=chamfer_L2*1000.0) |
| metric_logger.update(fscore=fscore) |
|
|
| if args.save_mesh: |
| m.export('{}/{}_mesh.ply'.format(object_folder, partial_name[j])) |
|
|
| if args.save_par_points: |
| par_point_input = data_batch['par_points'][j].numpy() |
| |
| par_point_o3d = o3d.geometry.PointCloud() |
| par_point_o3d.points = o3d.utility.Vector3dVector(par_point_input[:, 0:3]) |
| o3d.io.write_point_cloud('{}/{}.ply'.format(object_folder, partial_name[j]), par_point_o3d) |
| if args.save_image: |
| image_list=data_batch["org_image"] |
| for idx,image in enumerate(image_list): |
| image=image[0].numpy().astype(np.uint8) |
| if args.save_proj_img: |
| proj_mat=proj_matrices[j,idx].numpy() |
| proj_image=draw_proj_image(image,proj_mat,result_surface) |
| proj_save_path = '{}/proj_{}.jpg'.format(object_folder, idx) |
| cv2.imwrite(proj_save_path,proj_image) |
| save_path='{}/{}.jpg'.format(object_folder, idx) |
| cv2.imwrite(save_path,image) |
| if args.save_surface: |
| surface=gt_surface.numpy().astype(np.float32) |
| surface_o3d = o3d.geometry.PointCloud() |
| surface_o3d.points = o3d.utility.Vector3dVector(surface[:, 0:3]) |
| o3d.io.write_point_cloud('{}/surface.ply'.format(object_folder), surface_o3d) |
| metric_logger.synchronize_between_processes() |
| print('* iou {ious.global_avg:.3f}' |
| .format(ious=metric_logger.iou)) |
| if args.eval_cd: |
| print('* chamferl2 {chamferl2s.global_avg:.3f}' |
| .format(chamferl2s=metric_logger.chamferl2)) |
| print('* fscore {fscores.global_avg:.3f}' |
| .format(fscores=metric_logger.fscore)) |
|
|