# FastSplatStyler / styletransfer_splat.py
# incrl's picture
# Initial Upload (attempt 2)
# 5b557cf verified
import pointCloudToMesh as ply2M
import argparse
import utils
import graph_io as gio
from clusters import *
#from tqdm import tqdm,trange
import splat_mesh_helpers as splt
import clusters as cl
from torch_geometric.data import Data
from scipy.interpolate import LinearNDInterpolator, NearestNDInterpolator
import pointCloudToMesh as plyToMesh
import plotly.graph_objects as go
import pyvista as pv
from time import time
from graph_networks.LinearStyleTransfer_vgg import encoder,decoder
from graph_networks.LinearStyleTransfer_matrix import TransformLayer
from graph_networks.LinearStyleTransfer.libs.Matrix import MulLayer
from graph_networks.LinearStyleTransfer.libs.models import encoder4, decoder4
#import matplotlib.pyplot as plt
def _show_point_cloud(points, point_colors, point_size):
    """Open a PyVista window showing ``points`` coloured by ``point_colors``.

    Both arguments are converted with ``.numpy()``, so they must be CPU
    tensors.
    """
    cloud = pv.PolyData(points.numpy())
    # NOTE(review): the upper clamp bound of 3 is kept from the original code;
    # it looks unusual for data plotted with rgb=True -- confirm it is intended.
    cloud.point_data['colors'] = torch.clamp(point_colors, 0, 3).numpy()
    plotter = pv.Plotter()
    plotter.add_points(cloud, scalars='colors', rgb=True, point_size=point_size)
    plotter.show_axes()
    plotter.show()


def styletransfer_with_filtering_sampling(filename, stylePath, outPath, device = 'cpu', threshold=99.9, samplingRate=1.5, displayPointCloud = False):
    """Stylize a Gaussian-splat file with a style image and save the result.

    Pipeline: unpack the splat, optionally supersample it into a denser point
    cloud, build a surface graph with downsampled cluster levels, run the
    graph encoder / transform / decoder, then map the stylized colours back
    onto the original splat positions by nearest-neighbour lookup and save.

    Args:
        filename: Input splat path, passed to
            ``splt.splat_unpacker_with_threshold``.
        stylePath: Style image path, loaded by ``utils.loadImage`` at 512x512.
        outPath: Output path, passed to ``splt.splat_save``.
        device: Torch device used for the networks and the graph data.
        threshold: Forwarded unchanged to ``splt.splat_unpacker_with_threshold``
            and ``ply2M.Estimate_Normals``; its meaning is defined there.
        samplingRate: When greater than 1, the splat is supersampled to
            ``int(num_points * samplingRate)`` points; otherwise the original
            points and colours are used directly.
        displayPointCloud: When true, show the point cloud in a PyVista window
            before and after stylization.

    Returns:
        None. The stylized splat is written to ``outPath``.
    """
    print("Running on device:", device)

    n = 25
    ratio = .25
    depth = 3

    style_ref = utils.loadImage(stylePath, shape=(256*2,256*2))

    pos3D_Original, _, colors_Original, opacity_Original, scales_Original, rots_Original, fileType = splt.splat_unpacker_with_threshold(n, filename, threshold)

    # --- Stage 1: optional Gaussian supersampling ---------------------------
    time1_start = time()
    if samplingRate > 1:
        GaussianSamples = int(pos3D_Original.shape[0]*samplingRate)
        pos3D, colors = splt.splat_GaussianSuperSampler(pos3D_Original.clone(), colors_Original.clone(), opacity_Original.clone(), scales_Original.clone(), rots_Original.clone(), GaussianSamples)
    else:
        # No supersampling: `colors` aliases `colors_Original`, so the in-place
        # colour write further down also modifies the original colour tensor.
        pos3D, colors = pos3D_Original, colors_Original
    time1_end = time()

    print("Number of nodes in the graph:", pos3D.shape[0])
    print(f"Time taken for Gaussian Super Sampling: {time1_end - time1_start}")

    if displayPointCloud:
        _show_point_cloud(pos3D, colors, point_size=0.05)

    # --- Stage 2: graph construction -----------------------------------------
    time2_start = time()

    normalsNP = ply2M.Estimate_Normals(pos3D, threshold)
    normals = torch.from_numpy(normalsNP)

    up_vector = torch.tensor([[1,1,1]],dtype=torch.float)
    up_vector = up_vector/torch.linalg.norm(up_vector,dim=1)

    # The original code called `pos3D.to(device)` (and the same for colors,
    # normals and up_vector) without keeping the result. `Tensor.to` is not
    # in-place, so those calls did nothing and are dropped here. The tensors
    # deliberately stay on the CPU -- the `.numpy()` calls below depend on it --
    # and only the `Data` objects and networks are moved to `device`.

    # edge_index holds the neighbours of each point; directions are the
    # directions from that point to each neighbour.
    edge_index,directions = gh.surface2Edges(pos3D,normals,up_vector,k_neighbors=16)
    # Directions are converted into the selections "W sub n" of the star-like
    # coordinate system from Dr. Hart's interpolated-selectionconv repository.
    edge_index,selections,interps = gh.edges2Selections(edge_index,directions,interpolated=True)

    # Generate info for downsampled versions of the graph
    clusters, edge_indexes, selections_list, interps_list = cl.makeSurfaceClusters(pos3D,normals,edge_index,selections,interps,ratio=ratio,up_vector=up_vector,depth=depth,device=device)

    time2_end = time()
    print(f"Time taken for graph construction: {time2_end - time2_start}")

    # --- Stage 3: stylization --------------------------------------------------
    time3_start = time()

    # Final content graph, plus metadata for mapping the result afterwards
    content = Data(x=colors,clusters=clusters,edge_indexes=edge_indexes,selections_list=selections_list,interps_list=interps_list)
    content_meta = Data(pos3D=pos3D)

    style,_ = gio.image2Graph(style_ref,depth=3,device=device)

    # Load the reference image networks. Every checkpoint is remapped to the
    # requested device so weights saved from a GPU also load on a CPU-only
    # machine (previously only the transform matrix was remapped).
    weights_location = torch.device(device)
    enc_ref = encoder4()
    dec_ref = decoder4()
    matrix_ref = MulLayer('r41')
    enc_ref.load_state_dict(torch.load('graph_networks/LinearStyleTransfer/models/vgg_r41.pth',map_location=weights_location))
    dec_ref.load_state_dict(torch.load('graph_networks/LinearStyleTransfer/models/dec_r41.pth',map_location=weights_location))
    matrix_ref.load_state_dict(torch.load('graph_networks/LinearStyleTransfer/models/r41.pth',map_location=weights_location))

    # Copy weights to the graph networks
    enc = encoder(padding_mode="replicate")
    dec = decoder(padding_mode="replicate")
    matrix = TransformLayer()
    with torch.no_grad():
        enc.copy_weights(enc_ref)
        dec.copy_weights(dec_ref)
        matrix.copy_weights(matrix_ref)

    content = content.to(device)
    style = style.to(device)
    enc = enc.to(device)
    dec = dec.to(device)
    matrix = matrix.to(device)

    # Run the graph network
    with torch.no_grad():
        cF = enc(content)
        sF = enc(style)
        feature,transmatrix = matrix(cF['r41'],sF['r41'],
                                     content.edge_indexes[3],content.selections_list[3],
                                     style.edge_indexes[3],style.selections_list[3],
                                     content.interps_list[3] if hasattr(content,'interps_list') else None)
        result = dec(feature,content)
        result = result.clamp(0,1)

    # Write the stylized colours back into the CPU-side colour tensor.
    colors[:, 0:3] = result

    time3_end = time()
    print(f"Time taken for stylization: {time3_end - time3_start}")

    if displayPointCloud:
        _show_point_cloud(pos3D, colors, point_size=0.25)

    # --- Stage 4: map colours back onto the original splat positions ----------
    time4_start = time()

    # Each original splat takes the colour of its nearest sampled point.
    interp2 = NearestNDInterpolator(pos3D.cpu(), colors.cpu())
    results_OriginalNP = interp2(pos3D_Original)
    results_Original = torch.from_numpy(results_OriginalNP).to(torch.float32)
    colors_and_opacity_Original = torch.cat((results_Original, opacity_Original.unsqueeze(1)), dim=1)

    time4_end = time()
    print(f"Time taken for interpolation: {time4_end - time4_start}")

    # Save result
    splt.splat_save(pos3D_Original.numpy(), scales_Original.numpy(), rots_Original.numpy(), colors_and_opacity_Original.numpy(), outPath, fileType)
def main():
    """Parse command-line arguments and run the splat style transfer.

    Every parsed argument is forwarded by keyword to
    ``styletransfer_with_filtering_sampling``, so the argument names here
    must match that function's parameter names.
    """

    def device_arg(value):
        # argparse passes command-line values as strings, while the GPU
        # entries in `choices` are ints. Without this conversion "--device 0"
        # is rejected as an invalid choice. Non-numeric values such as "cpu"
        # pass through unchanged and are validated by `choices`.
        return int(value) if value.isdigit() else value

    parser = argparse.ArgumentParser()
    # Required positional argument. The former `default=''` was dropped
    # because argparse ignores defaults on required positionals.
    parser.add_argument(
        "filename",
        type=str,
        help="input splat file to stylize",
    )
    parser.add_argument(
        "--stylePath",
        type=str,
        default="style_ims/style0.jpg",
        help="style image path",
    )
    parser.add_argument(
        "--outPath",
        type=str,
        default='output.splat',
        help="path the stylized splat is written to",
    )
    parser.add_argument(
        "--device",
        type=device_arg,
        default=0 if torch.cuda.is_available() else "cpu",
        # The original appended `or ["cpu"]`, which could never take effect:
        # the list always contains at least "cpu".
        choices=[*range(torch.cuda.device_count()), "cpu"],
        help="CUDA device index, or 'cpu'",
    )
    parser.add_argument(
        "--threshold",
        type=float,
        default=99.8,
    )
    parser.add_argument(
        "--samplingRate",
        type=float,
        default=1.5,
        help="values above 1 supersample the splat to N*rate points",
    )
    parser.add_argument(
        "--displayPointCloud",
        action='store_true',
        help="show the point cloud before and after stylization",
    )

    args = parser.parse_args()
    styletransfer_with_filtering_sampling(**vars(args))


if __name__ == "__main__":
    main()