| import streamlit as st |
| import util |
| import torch |
| import render_util |
| import math |
| from pathlib import Path |
| from models import PosADANet |
| import json |
| import plotly.graph_objects as go |
| import gdown |
|
|
|
|
# Shared UI constants for the demo page.
FILE_PC_KEY = "File"  # point-cloud radio option that switches to the user's uploaded file
DEFAULT_COLOR = "#E1E1E1"  # initial value of the colour picker
point_color = "rgb(30, 20, 160)"  # marker colour of the 3-D point-cloud preview
|
|
|
|
@st.cache_resource
def load_model(path: str, num_controls: int, url: str):
    """
    Load the pretrained PosADANet from disk, downloading it from Google Drive first if it is missing.

    Wrapped in ``st.cache_resource``, so the model is built once and reused across reruns for the
    same arguments. Reads the module-level ``device``, which must be defined before the first call.

    :param path: path to save/load the pretrained model
    :param num_controls: length of style/control vector the model requires (6 for regular, 8 for metallic roughness)
    :param url: google drive url to download the model if it's not already downloaded
    :return: the pretrained model, moved to ``device`` and switched to eval mode
    :raises FileNotFoundError: if the weights are still missing after the download attempt
    """
    weights_path = Path(path)
    if not weights_path.exists():
        # Make sure the destination folder exists before downloading into it.
        weights_path.parent.mkdir(parents=True, exist_ok=True)
        with st.spinner('Downloading Model'):
            gdown.download(url, path, quiet=False)
        # Fail with a clear message instead of an opaque error from torch.load below.
        if not weights_path.exists():
            raise FileNotFoundError(
                f'Model weights not found at {path!r} after downloading from {url}'
            )

    # NOTE(review): the positional args presumably mean 1 input channel (the z-buffer) and
    # 4 output channels - confirm against models.PosADANet.
    model = PosADANet(1, 4, num_controls, padding='zeros', bilinear=True).to(device)
    model.load_state_dict(torch.load(path, map_location=device))
    model.eval()
    return model
|
|
|
|
def load_dict_data(path: str):
    """
    Load a JSON file.

    The file is decoded as UTF-8 explicitly (the standard JSON encoding) instead of with the
    platform's locale encoding, so non-ASCII content loads identically on every OS.

    :param path: path to json file
    :return: the parsed JSON document (a dict for the settings files this demo reads)
    """
    with open(path, 'r', encoding='utf-8') as file:
        return json.load(file)
|
|
|
|
def to_rgb(hex_color: str):
    """
    Convert a hex colour string to RGB floats.

    Accepts ``#RRGGBB``, CSS shorthand ``#RGB`` and ``#RRGGBBAA`` (the alpha byte is ignored).
    The leading ``#`` is optional.

    :param hex_color: color hex string
    :return: list of three numbers for RGB channels between 0-1
    :raises ValueError: if the string is not a 3-, 6- or 8-digit hex colour
    """
    digits = hex_color.lstrip('#')
    if len(digits) == 3:
        # Shorthand: every digit is doubled, e.g. 'f0a' -> 'ff00aa'.
        digits = ''.join(ch * 2 for ch in digits)
    if len(digits) not in (6, 8):
        raise ValueError(f'Invalid hex colour: {hex_color!r}')
    return [int(digits[i:i + 2], 16) / 255 for i in (0, 2, 4)]
|
|
|
|
st.title('Z2P - Demo')

# Run on the current CUDA device when one is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device(torch.cuda.current_device())
else:
    device = torch.device('cpu')

st.subheader('Settings')

# Per-model and per-point-cloud presets; their keys are the options offered in the UI radios.
model_data = load_dict_data('models/default_settings.json')
pc_data = load_dict_data('point_clouds/default_settings.json')
|
|
col1_head, col2_head = st.columns(2)

model_key = col2_head.radio('Choose Model', model_data.keys())
pc_key = col2_head.radio('Choose Point Cloud', pc_data.keys())

# Restrict the picker to the formats the label and the warning below advertise.
uploaded_file = col2_head.file_uploader(
    'Upload Your Own Point Cloud (.xyz, .obj)', type=['xyz', 'obj']
)

if pc_key == FILE_PC_KEY:
    # The uploaded file is only used when the special 'File' preset is selected.
    if uploaded_file is None:
        st.warning('Please upload a .xyz or .obj file')
        st.stop()
    try:
        txt = uploaded_file.getvalue().decode("utf-8")
    except UnicodeDecodeError:
        # Show a readable message instead of a raw traceback for non-text uploads.
        st.error('Could not read the uploaded file as UTF-8 text')
        st.stop()
    pc = util.xyz2tensor(txt, append_normals=True)
else:
    # Built-in point cloud: the preset stores the path of its file.
    pc = util.read_xyz_file(pc_data[pc_key]['path'])
|
|
st.header('Input')
col1, col2 = st.columns(2)

# Transformation sliders start from the selected preset's stored defaults.
col2.subheader("Point Cloud Transformations")
pc_preset = pc_data[pc_key]
scale = col2.slider('Scale', min_value=0.0, max_value=5.0, value=pc_preset['scale'])
rx = col2.slider('X-Rotation', min_value=-math.pi, max_value=math.pi, value=pc_preset['rx'])
ry = col2.slider('Y-Rotation', min_value=-math.pi, max_value=math.pi, value=pc_preset['ry'])
rz = col2.slider('Z-Rotation', min_value=-math.pi, max_value=math.pi, value=pc_preset['rz'])
dy = col2.slider('Height', min_value=0, max_value=500, value=pc_preset['dy'])

col1.subheader("Input Z-Buffer")

# Rotate the cloud, then show an interactive 3-D preview (z is negated for display).
pc = render_util.rotate_pc(pc, rx, ry, rz)
trace1 = [go.Scatter3d(x=pc[:, 0], y=pc[:, 1], z=-pc[:, 2], mode="markers",
                       marker=dict(
                           symbol="circle",
                           size=1,
                           color=point_color))]
fig = go.Figure(trace1, layout=go.Layout())
col1_head.plotly_chart(fig, use_container_width=True)

# Rasterise the rotated cloud into the z-buffer that is fed to the model.
zbuffer = render_util.draw_pc(pc, radius=model_data[model_key]['point_radius'], dy=dy, scale=scale)

# Normalise for display only. Without the guard, an all-zero buffer would give 0/0 = NaN pixels.
zbuffer_peak = zbuffer.max()
col1.image(zbuffer / zbuffer_peak if zbuffer_peak > 0 else zbuffer, use_column_width=True)

zbuffer: torch.Tensor = torch.from_numpy(zbuffer).float().to(device)
|
|
st.header('Result')

model_settings = model_data[model_key]
len_style = model_settings['len_style']
model = load_model(model_settings['path'], len_style, model_settings['url'])
col1, col2 = st.columns(2)
col2.subheader('Visualization Controls')

# The (H, W) z-buffer becomes a batch of one single-channel image: (1, 1, H, W).
# It is already a float tensor on `device`, so no further conversion is needed.
zbuffer = zbuffer.unsqueeze(0).unsqueeze(0)

# Style/control vector: [r, g, b, light radius, light phi, light theta(, metallic, roughness)].
style = torch.zeros(len_style, dtype=zbuffer.dtype, device=device)

hex_color = col2.color_picker('Pick A Color', DEFAULT_COLOR)
# NOTE(review): colour channels are capped at 0.9 - presumably the range seen in training; confirm.
style[:3] = torch.tensor(to_rgb(hex_color), device=device).clip(0.0, 0.9)

style[3] = col2.slider('Light Radius', min_value=-1.0, max_value=1.0, value=0.0)
style[4] = col2.slider('Light Phi', min_value=-math.pi / 4, max_value=math.pi / 4, value=0.0)
style[5] = col2.slider('Light Theta', min_value=-math.pi / 4, max_value=math.pi / 4, value=0.0)

# Only the metallic-roughness model (8 controls) takes the two extra material sliders.
if len_style == 8:
    style[6] = col2.slider('Metallic', min_value=0.0, max_value=1.0, value=0.5)
    style[7] = col2.slider('Roughness', min_value=0.0, max_value=1.0, value=0.5)

style = style.unsqueeze(0)  # add the batch dimension

with torch.no_grad():
    generated = model(zbuffer, style)

# util.embed_color receives the chosen RGB (presumably drawn into the image as a 50-px box; see util).
# The first image of the batch is then reordered with permute(1, 2, 0) for display
# (assumes channel-first output).
generated = util.embed_color(generated, style[:, :3], box_size=50)
rendered = generated[0].permute(1, 2, 0).cpu().numpy()

col1.image(rendered.clip(0, 1), use_column_width=True)
|
|