Spaces:
Sleeping
Sleeping
| import pointCloudToMesh as ply2M | |
| import argparse | |
| import utils | |
| import graph_io as gio | |
| from clusters import * | |
| #from tqdm import tqdm,trange | |
| import splat_mesh_helpers as splt | |
| import clusters as cl | |
| from torch_geometric.data import Data | |
| from scipy.interpolate import LinearNDInterpolator, NearestNDInterpolator | |
| import pointCloudToMesh as plyToMesh | |
| import plotly.graph_objects as go | |
| import pyvista as pv | |
| from time import time | |
| from graph_networks.LinearStyleTransfer_vgg import encoder,decoder | |
| from graph_networks.LinearStyleTransfer_matrix import TransformLayer | |
| from graph_networks.LinearStyleTransfer.libs.Matrix import MulLayer | |
| from graph_networks.LinearStyleTransfer.libs.models import encoder4, decoder4 | |
| #import matplotlib.pyplot as plt | |
def _show_point_cloud(points, point_colors, point_size):
    """Display ``points`` coloured by ``point_colors`` in an interactive PyVista window."""
    cloud = pv.PolyData(points.cpu().numpy())
    # PyVista interprets float RGB scalars on a 0-1 scale, so clamp to [0, 1]
    # (matching the clamp used elsewhere in this module for point-cloud previews).
    cloud.point_data['colors'] = torch.clamp(point_colors, 0, 1).cpu().numpy()
    plotter = pv.Plotter()
    plotter.add_points(cloud, scalars='colors', rgb=True, point_size=point_size)
    plotter.show_axes()
    plotter.show()


def _load_style_networks(device):
    """Build the graph LinearStyleTransfer (r41) encoder, decoder and transform on ``device``."""
    enc_ref = encoder4()
    dec_ref = decoder4()
    matrix_ref = MulLayer('r41')
    # The reference models are CPU modules that only act as a weight source, so all
    # checkpoints are loaded onto the CPU. This also lets GPU-saved checkpoints load
    # on machines without CUDA.
    enc_ref.load_state_dict(torch.load('graph_networks/LinearStyleTransfer/models/vgg_r41.pth', map_location='cpu'))
    dec_ref.load_state_dict(torch.load('graph_networks/LinearStyleTransfer/models/dec_r41.pth', map_location='cpu'))
    matrix_ref.load_state_dict(torch.load('graph_networks/LinearStyleTransfer/models/r41.pth', map_location='cpu'))

    # Copy the pretrained weights into the graph-network counterparts.
    enc = encoder(padding_mode="replicate")
    dec = decoder(padding_mode="replicate")
    matrix = TransformLayer()
    with torch.no_grad():
        enc.copy_weights(enc_ref)
        dec.copy_weights(dec_ref)
        matrix.copy_weights(matrix_ref)
    return enc.to(device), dec.to(device), matrix.to(device)


def styletransfer_with_filtering_sampling(filename, stylePath, outPath, device = 'cpu', threshold=99.9, samplingRate=1.5, displayPointCloud = False):
    """Stylize a Gaussian-splat file with a style image and save the result.

    Pipeline:
      1. Load the splat with ``splt.splat_unpacker_with_threshold``.
      2. If ``samplingRate > 1``, super-sample ``int(N * samplingRate)`` points with
         ``splt.splat_GaussianSuperSampler``; otherwise use the loaded points.
      3. Estimate normals, build a 16-neighbour surface graph and a ``depth``-level
         cluster hierarchy.
      4. Run the graph version of LinearStyleTransfer (r41 encoder/decoder/transform),
         with the style image loaded at 512x512.
      5. Map the stylized colours back onto the original points by nearest-neighbour
         interpolation. Save them to ``outPath`` together with the original opacity,
         scales and rotations.

    Args:
        filename: path of the input splat file.
        stylePath: path of the style image.
        outPath: path of the output splat file.
        device: torch device spec for the networks ('cpu' or a CUDA index).
        threshold: forwarded to the splat loader and to ``ply2M.Estimate_Normals``
            (its semantics are defined there).
        samplingRate: super-sampling factor; values <= 1 disable super-sampling.
        displayPointCloud: if True, show the point cloud before and after stylization.

    Returns:
        None. The result is written to ``outPath``.
    """
    print("Running on device:", device)
    n = 25  # forwarded to splt.splat_unpacker_with_threshold; meaning defined there
    style_ref = utils.loadImage(stylePath, shape=(256*2, 256*2))
    ratio = .25  # passed to makeSurfaceClusters as the per-level ratio
    depth = 3    # number of cluster levels for both the content and style graphs
    (pos3D_Original, _, colors_Original, opacity_Original,
     scales_Original, rots_Original, fileType) = splt.splat_unpacker_with_threshold(n, filename, threshold)

    time1_start = time()
    if samplingRate > 1:
        GaussianSamples = int(pos3D_Original.shape[0]*samplingRate)
        pos3D, colors = splt.splat_GaussianSuperSampler(
            pos3D_Original.clone(), colors_Original.clone(), opacity_Original.clone(),
            scales_Original.clone(), rots_Original.clone(), GaussianSamples)
    else:
        pos3D, colors = pos3D_Original, colors_Original
    time1_end = time()
    print("Number of nodes in the graph:", pos3D.shape[0])
    print(f"Time taken for Gaussian Super Sampling: {time1_end - time1_start}")
    if displayPointCloud:
        _show_point_cloud(pos3D, colors, point_size=0.05)

    time2_start = time()
    normals = torch.from_numpy(ply2M.Estimate_Normals(pos3D, threshold))
    up_vector = torch.tensor([[1, 1, 1]], dtype=torch.float)
    up_vector = up_vector/torch.linalg.norm(up_vector, dim=1)
    # Tensor.to() is not in-place, so the graph is built on the tensors' current
    # device (presumably CPU from the loaders above - TODO confirm). The assembled
    # graph is moved to `device` later via content.to(device).
    # edge_index holds each point's neighbours; directions are the offsets to them.
    edge_index, directions = gh.surface2Edges(pos3D, normals, up_vector, k_neighbors=16)
    # Turn directions into selections ("W sub n") of the star-like coordinate system
    # from Dr. Hart's interpolated-selectionconv.
    edge_index, selections, interps = gh.edges2Selections(edge_index, directions, interpolated=True)
    # Build the downsampled graph levels.
    clusters, edge_indexes, selections_list, interps_list = cl.makeSurfaceClusters(
        pos3D, normals, edge_index, selections, interps,
        ratio=ratio, up_vector=up_vector, depth=depth, device=device)
    time2_end = time()
    print(f"Time taken for graph construction: {time2_end - time2_start}")

    time3_start = time()
    content = Data(x=colors, clusters=clusters, edge_indexes=edge_indexes,
                   selections_list=selections_list, interps_list=interps_list)
    style, _ = gio.image2Graph(style_ref, depth=depth, device=device)
    enc, dec, matrix = _load_style_networks(device)
    content = content.to(device)
    style = style.to(device)

    with torch.no_grad():
        cF = enc(content)
        sF = enc(style)
        # Graph data at index 3 is paired with the r41 features (assumed to be the
        # coarsest level produced with depth=3 - TODO confirm).
        feature, _ = matrix(cF['r41'], sF['r41'],
                            content.edge_indexes[3], content.selections_list[3],
                            style.edge_indexes[3], style.selections_list[3],
                            content.interps_list[3] if hasattr(content, 'interps_list') else None)
        result = dec(feature, content).clamp(0, 1)
    # Only the first three channels are overwritten. The result is moved explicitly
    # to the device of `colors`.
    colors[:, 0:3] = result.to(colors.device)
    time3_end = time()
    print(f"Time taken for stylization: {time3_end - time3_start}")
    if displayPointCloud:
        _show_point_cloud(pos3D, colors, point_size=0.25)

    time4_start = time()
    # Transfer the stylized colours back to the original (pre-super-sampling) points.
    nearest = NearestNDInterpolator(pos3D.cpu(), colors.cpu())
    results_Original = torch.from_numpy(nearest(pos3D_Original)).to(torch.float32)
    colors_and_opacity_Original = torch.cat((results_Original, opacity_Original.unsqueeze(1)), dim=1)
    time4_end = time()
    print(f"Time taken for interpolation: {time4_end - time4_start}")

    splt.splat_save(pos3D_Original.numpy(), scales_Original.numpy(), rots_Original.numpy(),
                    colors_and_opacity_Original.numpy(), outPath, fileType)
def _device_arg(value):
    """Parse a ``--device`` value: integer strings become CUDA indices, anything else is kept as-is."""
    try:
        return int(value)
    except ValueError:
        return value


def main():
    """Parse command-line arguments and run ``styletransfer_with_filtering_sampling``."""
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "filename",
        type=str,
        help="input splat file",
    )
    parser.add_argument(
        "--stylePath",
        type=str,
        default="style_ims/style0.jpg",
        help="style image",
    )
    parser.add_argument(
        "--outPath",
        type=str,
        default='output.splat',
        help="output splat file",
    )
    parser.add_argument(
        "--device",
        # Without a type converter, "--device 0" arrives as the string "0", which is
        # never in the integer choices below. Convert digit strings to ints.
        type=_device_arg,
        default=0 if torch.cuda.is_available() else "cpu",
        choices=list(range(torch.cuda.device_count())) + ["cpu"],
        help="CUDA device index or 'cpu'",
    )
    parser.add_argument(
        "--threshold",
        type=float,
        default=99.8,
    )
    parser.add_argument(
        "--samplingRate",
        type=float,
        default=1.5,
        help="super-sampling factor; <= 1 disables super-sampling",
    )
    parser.add_argument(
        "--displayPointCloud",
        action='store_true',
        help="show the point cloud before and after stylization",
    )
    args = parser.parse_args()
    styletransfer_with_filtering_sampling(**vars(args))


if __name__ == "__main__":
    main()