| import torch.nn as nn |
| import sys,os |
| sys.path.append("..") |
| import torch |
| from datasets import build_dataset |
| from configs.config_utils import CONFIG |
| from torch.utils.data import DataLoader |
| from models.modules import PointEmbed |
| from models.modules import ConvPointnet_Encoder,ConvPointnet_Decoder |
| import numpy as np |
|
|
class TriplaneVAE(nn.Module):
    """VAE that encodes a surface point cloud into plane features and decodes
    per-query outputs from those features.

    The module wires together three sub-modules: a point embedder shared by
    surface and query points, a ``ConvPointnet_Encoder`` and a
    ``ConvPointnet_Decoder``.
    """

    def __init__(self, opt):
        """Build the embedder, encoder and decoder from a config mapping.

        :param opt: mapping providing ``point_emb_dim``, ``padding``,
            ``encoder`` (``plane_latent_dim``, ``latent_dim``, ``plane_reso``,
            ``unet``) and ``decoder`` (``latent_dim``, ``query_emb_dim``,
            ``hidden_dim``, ``unet``, ``n_blocks``, ``plane_reso``).
        """
        super().__init__()
        self.point_embedder = PointEmbed(hidden_dim=opt['point_emb_dim'])

        encoder_args = opt['encoder']
        decoder_args = opt['decoder']
        self.encoder = ConvPointnet_Encoder(
            c_dim=encoder_args['plane_latent_dim'],
            dim=opt['point_emb_dim'],
            latent_dim=encoder_args['latent_dim'],
            plane_resolution=encoder_args['plane_reso'],
            unet_kwargs=encoder_args['unet'],
            unet=True,
            padding=opt['padding'],
        )
        self.decoder = ConvPointnet_Decoder(
            latent_dim=decoder_args['latent_dim'],
            query_emb_dim=decoder_args['query_emb_dim'],
            hidden_dim=decoder_args['hidden_dim'],
            unet_kwargs=decoder_args['unet'],
            n_blocks=decoder_args['n_blocks'],
            plane_resolution=decoder_args['plane_reso'],
            padding=opt['padding'],
        )

    @staticmethod
    def _downsample_and_restore(plane_feat):
        """Bilinearly halve the plane resolution, then resize back to the input size.

        The second resize targets the recorded input size instead of a fixed
        ``scale_factor=2``: for an odd height/width the halved size is floored,
        so doubling it would return a plane smaller than the input.
        """
        full_size = tuple(plane_feat.shape[-2:])
        coarse = torch.nn.functional.interpolate(
            plane_feat, scale_factor=0.5, mode="bilinear", align_corners=False
        )
        return torch.nn.functional.interpolate(
            coarse, size=full_size, mode="bilinear", align_corners=False
        )

    def forward(self, p, query):
        '''
        Encode the surface points and decode outputs at the query points.

        :param p: surface points cloud of shape B,N,3
        :param query: sample points of shape B,N,3
        :return: dict with ``'logits'`` (decoder output for the queries) and
            ``'kl'`` (first value returned by the encoder)
        '''
        point_emb = self.point_embedder(p)
        query_emb = self.point_embedder(query)
        # The encoder also returns means and log-variances; forward() does not use them.
        kl, plane_feat, _, _ = self.encoder(p, point_emb)

        # Training-only augmentation, applied to roughly half of the calls:
        # randomly scale the triplane down and back up to promote robustness
        # (original note: triplane diffusion is conducted on the 64x64x64 plane).
        if self.training and np.random.random() < 0.5:
            plane_feat = self._downsample_and_restore(plane_feat)

        o = self.decoder(plane_feat, query, query_emb)
        return {'logits': o, 'kl': kl}

    def decode(self, plane_feature, query):
        """Decode outputs for ``query`` points from already-computed plane features.

        :param plane_feature: plane features, e.g. as returned by :meth:`encode`
        :param query: sample points to evaluate
        :return: decoder output for the query points
        """
        query_embedding = self.point_embedder(query)
        return self.decoder(plane_feature, query, query_embedding)

    def encode(self, p):
        """Encode a point cloud into plane features.

        :param p: point cloud of B,N,3
        :return: tuple ``(plane_feat, kl, mean, logvar)`` -- note the order
            differs from the encoder's own ``(kl, plane_feat, mean, logvar)``
        """
        point_emb = self.point_embedder(p)
        kl, plane_feat, mean, logvar = self.encoder(p, point_emb)
        return plane_feat, kl, mean, logvar
|
|
if __name__ == "__main__":
    # Smoke test: load the training config, build the dataset and push a
    # single batch through a freshly constructed TriplaneVAE.
    configs = CONFIG("../configs/train_triplane_vae_64.yaml")
    config = configs.config
    dataset_config = config['datasets']
    model_config = config["model"]

    dataset = build_dataset("train", dataset_config)
    # Load one sample in-process first so a broken dataset fails here rather
    # than inside a DataLoader worker.
    _ = dataset[0]
    dataloader = DataLoader(
        dataset=dataset,
        batch_size=10,
        shuffle=True,
        num_workers=2,
    )

    # Fall back to CPU so the smoke test does not crash on machines without CUDA.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    net = TriplaneVAE(model_config).float().to(device)

    for data_batch in dataloader:
        surface = data_batch['surface'].float().to(device)
        query = data_batch['points'].float().to(device)
        net(surface, query)
        # Only the first batch is needed; stop before fetching another one.
        break
|
|
|
|
|
|