diff --git a/.gitattributes b/.gitattributes index a6344aac8c09253b3b630fb776ae94478aa0275b..bc40165288b30bd5740b8971709a3279d53fdae7 100644 --- a/.gitattributes +++ b/.gitattributes @@ -33,3 +33,38 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text *.zip filter=lfs diff=lfs merge=lfs -text *.zst filter=lfs diff=lfs merge=lfs -text *tfevents* filter=lfs diff=lfs merge=lfs -text +demo_images/panorama/_2h2HQiO_LScptNqIlv26g,41.876157,-87.656242,.jpg filter=lfs diff=lfs merge=lfs -text +demo_images/panorama/_cF0n90rix962-yUqP13Ow,41.870943,-87.647100,.jpg filter=lfs diff=lfs merge=lfs -text +demo_images/panorama/_MubtGZKtddAL0UcHHf5Cg,41.861737,-87.644142,.jpg filter=lfs diff=lfs merge=lfs -text +demo_images/panorama/_U8e_qINb8isqDbFusZadg,41.867593,-87.673364,.jpg filter=lfs diff=lfs merge=lfs -text +demo_images/panorama/default_panorama.jpg filter=lfs diff=lfs merge=lfs -text +demo_images/satellite/sat_demo_1.png filter=lfs diff=lfs merge=lfs -text +demo_images/satellite/sat_demo_10.png filter=lfs diff=lfs merge=lfs -text +demo_images/satellite/sat_demo_2.png filter=lfs diff=lfs merge=lfs -text +demo_images/satellite/sat_demo_3.png filter=lfs diff=lfs merge=lfs -text +demo_images/satellite/sat_demo_4.png filter=lfs diff=lfs merge=lfs -text +demo_images/satellite/sat_demo_5.png filter=lfs diff=lfs merge=lfs -text +demo_images/satellite/sat_demo_6.png filter=lfs diff=lfs merge=lfs -text +demo_images/satellite/sat_demo_7.png filter=lfs diff=lfs merge=lfs -text +demo_images/satellite/sat_demo_8.png filter=lfs diff=lfs merge=lfs -text +demo_images/satellite/sat_demo_9.png filter=lfs diff=lfs merge=lfs -text +demo_images/satellite/satellite_47.564866521505394_-122.33116389229833.png filter=lfs diff=lfs merge=lfs -text +demo_images/satellite/satellite_47.57109989810551_-122.30149079418568.png filter=lfs diff=lfs merge=lfs -text +demo_images/satellite/satellite_47.57372447772661_-122.30489590380517.png filter=lfs diff=lfs merge=lfs -text +demo_images/satellite/satellite_47.574380622631885_-122.32192145190258.png filter=lfs diff=lfs merge=lfs -text +demo_images/satellite/satellite_47.57503676753716_-122.33456900191781.png filter=lfs diff=lfs merge=lfs -text +demo_images/satellite/satellite_47.576020984895074_-122.31657056535768.png filter=lfs diff=lfs merge=lfs -text +demo_images/satellite/satellite_47.5779894196109_-122.31608412112634.png filter=lfs diff=lfs merge=lfs -text +demo_images/satellite/satellite_47.58356665130574_-122.2873839114764.png filter=lfs diff=lfs merge=lfs -text +demo_images/satellite/satellite_47.59012810035849_-122.32240789613394.png filter=lfs diff=lfs merge=lfs -text +demo_images/satellite/satellite_47.59111231771641_-122.28981613263318.png filter=lfs diff=lfs merge=lfs -text +demo_images/satellite/satellite_47.59176846262168_-122.29419413071537.png filter=lfs diff=lfs merge=lfs -text +demo_images/satellite/satellite_47.595377259600696_-122.32435367305936.png filter=lfs diff=lfs merge=lfs -text +demo_images/satellite/satellite_47.59767376676916_-122.33651477884322.png filter=lfs diff=lfs merge=lfs -text +demo_images/satellite/satellite_47.6006264188429_-122.29419413071537.png filter=lfs diff=lfs merge=lfs -text +demo_images/satellite/satellite_47.60390714336928_-122.31559767689498.png filter=lfs diff=lfs merge=lfs -text +demo_images/satellite/satellite_47.616373896569506_-122.31024679035008.png filter=lfs diff=lfs merge=lfs -text +demo_images/satellite/satellite_47.616373896569506_-122.3131654557382.png filter=lfs diff=lfs merge=lfs -text 
+demo_images/satellite/satellite_47.618014258832694_-122.29857212879756.png filter=lfs diff=lfs merge=lfs -text +demo_images/satellite/satellite_47.61867040373797_-122.3024636826484.png filter=lfs diff=lfs merge=lfs -text +demo_images/satellite/satellite_47.62227920071699_-122.29176190955859.png filter=lfs diff=lfs merge=lfs -text diff --git a/README.md b/README.md index a10f6af2511087c0d7bc48f7dcb5477d98fb934e..54bdb0111fb39b5bb7e64bf71717bc793e4019bd 100644 --- a/README.md +++ b/README.md @@ -1,14 +1,35 @@ --- title: Sat3DGen -emoji: 🔥 +emoji: 🛰️ colorFrom: blue -colorTo: blue +colorTo: green sdk: gradio -sdk_version: 6.13.0 +sdk_version: "4.44.1" app_file: app.py pinned: false license: mit -short_description: '[ICLR 2026] Sat3DGen: Comprehensive Street-Level 3D Scene Ge' +models: + - qian43/Sat3DGen +suggested_hardware: t4-small --- -Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference +## Sat3DGen – Street-Level 3D Scene Generation from Satellite Images + +**[ICLR 2026]** Generate street-level 3D scenes from a single satellite image. + +- 📄 [Paper (OpenReview)](https://openreview.net/forum?id=E7JzkZCofa) +- 🌐 [Project Page](https://qianmingduowan.github.io/Sat3DGen_project_page/) +- 💻 [GitHub](https://github.com/qianmingduowan/Sat3DGen) +- 🤗 [Model](https://huggingface.co/qian43/Sat3DGen) + +### Features + +- **3D Mesh Generation**: Upload a satellite image → get a downloadable `.obj` mesh with in-browser 3D preview. +- **Video Rendering**: Select a satellite image and sky panorama → render a walkthrough video along a trajectory. + +### Usage + +Upload a satellite image or select one from the examples, then click "Generate". +The model weights are loaded automatically from HuggingFace. + +> **Note**: GPU hardware is recommended for reasonable inference speed. diff --git a/app.py b/app.py new file mode 100644 index 0000000000000000000000000000000000000000..765e92fb630210175d27ca1e21f9a7d0b45d1616 --- /dev/null +++ b/app.py @@ -0,0 +1,724 @@ +"""Sat3DGen Gradio Demo. + +Two-step interactive demo: + 1. Upload a satellite image -> generate and visualize a 3D mesh. + 2. Select a demo image with a pre-generated trajectory -> render panorama + perspective video. +""" + +import csv +import os +import shutil +import subprocess +import tempfile +from pathlib import Path +from typing import List, Optional, Tuple + +import cv2 +import gradio as gr +import numpy as np +import open3d as o3d +import torch +import torchvision.transforms as T +import trimesh +from PIL import Image + +from source.generator import Sat3DGen +from source.rendering.transform_perspective import compose_rotmat + +# --------------------------------------------------------------------------- +# Global state +# --------------------------------------------------------------------------- +DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") +MODEL: Optional[Sat3DGen] = None +PATCH_SIZE: int = 16 +SAT_TRANSFORM = None +RESULTS_DIR = Path("./results/gradio_demo") +TRAJECTORY_PREVIEW_SIZE = 256 +DEFAULT_SKY_FILENAMES = ( + "default_panorama.jpg", + "default_panorama.png", + "default_panorama.jpeg", + "default_demo_panorama.jpg", + "default_demo_panorama.png", + "default_demo_panorama.jpeg", + "default_sky.jpg", + "default_sky.png", + "default_sky.jpeg", +) +RESULTS_DIR.mkdir(parents=True, exist_ok=True) + + +HUGGINGFACE_REPO = "qian43/Sat3DGen" + +def load_model(checkpoint_path: str = "checkpoints"): + """Load the Sat3DGen model (singleton). + + Resolution order: + 1. 
Local *checkpoint_path* directory (if it contains model files). + 2. HuggingFace Hub repo ``qian43/Sat3DGen``. + + When loading from a full checkpoint (local or Hub), the backbone + weights are already included in the safetensors file, so the + standalone DINOv3 download is skipped automatically. + """ + global MODEL, PATCH_SIZE, SAT_TRANSFORM + + if MODEL is not None: + return + + model_path: str | None = None + checkpoint_path_obj = Path(checkpoint_path) + if (checkpoint_path_obj / "config.json").exists(): + model_path = str(checkpoint_path_obj) + elif (checkpoint_path_obj / "vqmodel_ema").exists(): + model_path = str(checkpoint_path_obj / "vqmodel_ema") + elif (checkpoint_path_obj / "vqmodel").exists(): + model_path = str(checkpoint_path_obj / "vqmodel") + + if model_path is None: + model_path = HUGGINGFACE_REPO + print(f"Local checkpoint not found at '{checkpoint_path}', loading from HuggingFace: {HUGGINGFACE_REPO}") + + # Skip redundant backbone weight download – from_pretrained will + # overwrite all parameters from the safetensors file anyway. + Sat3DGen._skip_backbone_weights = True + print(f"Loading model from {model_path} ...") + MODEL = Sat3DGen.from_pretrained(model_path).to(DEVICE) + Sat3DGen._skip_backbone_weights = False + MODEL.eval() + PATCH_SIZE = MODEL.unet_model.patch_size if hasattr(MODEL.unet_model, "patch_size") else 16 + SAT_TRANSFORM = T.Compose([ + T.Resize((PATCH_SIZE * 16, PATCH_SIZE * 16), interpolation=Image.BICUBIC), + T.ToTensor(), + T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]), + ]) + print("Model loaded successfully.") + + +# --------------------------------------------------------------------------- +# Utility helpers (adapted from single_image_inference.py) +# --------------------------------------------------------------------------- + +def save_obj(vertices: np.ndarray, faces: np.ndarray, colors: np.ndarray, filepath: str): + vertices = vertices @ np.array([[1, 0, 0], [0, 1, 0], [0, 0, -1]]) + faces = faces[:, [2, 1, 0]] + mesh = o3d.geometry.TriangleMesh() + mesh.vertices = o3d.utility.Vector3dVector(vertices) + mesh.triangles = o3d.utility.Vector3iVector(faces) + mesh.vertex_colors = o3d.utility.Vector3dVector(colors / 255.0) + o3d.io.write_triangle_mesh(filepath, mesh, write_vertex_normals=False) + + +def position_to_c2w(position: Tuple[float, float, float]) -> torch.Tensor: + rotation = compose_rotmat(0, 0, 0) + pos = np.array(position, dtype=np.float32) + pos[0] *= -1 + pos = pos[[1, 0, 2]] + c2w = np.eye(4, dtype=np.float32) + c2w[:3, :3] = np.array(rotation, dtype=np.float32) + c2w[:3, 3] = pos + return torch.from_numpy(c2w).unsqueeze(0).to(DEVICE) + + +def build_intrinsics() -> torch.Tensor: + fovx, fovy = 120, 120 + fx = 0.5 * 256 / np.tan(0.5 * fovx / 180.0 * np.pi) + fy = 0.5 * 256 / np.tan(0.5 * fovy / 180.0 * np.pi) + cx = (256 - 1) / 2.0 + cy = (256 - 1) / 2.0 + intrinsics = np.array([[fx / 2, 0, cx / 2], [0, fy / 2, cy / 2], [0, 0, 1]], dtype=np.float32) + return torch.from_numpy(intrinsics).unsqueeze(0).to(DEVICE) + + +def tensor_to_numpy_rgb(tensor: torch.Tensor) -> np.ndarray: + """Convert a [1, C, H, W] or [C, H, W] tensor in [0, 1] to a uint8 RGB numpy array.""" + img = tensor.detach().cpu().clamp(0, 1) + if img.dim() == 4: + img = img.squeeze(0) + return (img.permute(1, 2, 0).numpy() * 255).astype(np.uint8) + + +def get_pano_rgb(output) -> torch.Tensor: + if hasattr(output.str_output, "sr_image"): + return output.str_output.sr_image + return output.str_output.image_raw_compo + + +def get_per_rgb(output) -> 
torch.Tensor: + if hasattr(output.per_output, "sr_image"): + return output.per_output.sr_image + return output.per_output.image_raw_compo + + +def make_histo(grd_img_path: str) -> torch.Tensor: + grd_img = Image.open(grd_img_path).convert("RGB").resize((512, 128)) + grd_img = T.ToTensor()(grd_img).unsqueeze(0).float().to(DEVICE) + + # Derive the sky-mask path by replacing only the parent directory name, + # keeping the filename intact (just switching extension to .png). + grd_path = Path(grd_img_path) + parent_name = grd_path.parent.name + if parent_name in ("streetview", "panorama"): + mask_dir = grd_path.parent.parent / "pano_sky_mask" + mask_img_path = str(mask_dir / grd_path.with_suffix(".png").name) + else: + raise ValueError(f"Cannot infer sky-mask path from {grd_img_path}") + + mask_img = Image.open(mask_img_path).convert("L").resize((512, 128), Image.NEAREST) + mask_img = T.ToTensor()(mask_img).unsqueeze(0).float().to(DEVICE) + + sky_image = (grd_img * mask_img).mul(2).sub(1) + sky_image = sky_image.detach().cpu().numpy() + + from source.sky_histogram import compute_sky_histogram + + histo_sky = torch.from_numpy( + compute_sky_histogram(sky_image[0], hist_range=(-1, 1)) + ).unsqueeze(0).float().to(DEVICE) + return histo_sky + + +def read_trajectory_from_csv(csv_path: str, sat_image_size: int) -> Tuple[List[Tuple[float, float, float]], np.ndarray]: + """Read a pre-generated trajectory .csv file (format: w,h,angle). + + Returns: + positions: list of (x_norm, y_norm, z) in [-1, 1] range for rendering + pixel_coords: Nx2 array of pixel coordinates for visualization + """ + half = sat_image_size / 2 + positions = [] + pixel_coords = [] + with open(csv_path, "r") as f: + reader = csv.DictReader(f) + for row in reader: + px = float(row["w"]) + py = float(row["h"]) + pixel_coords.append((px, py)) + positions.append(((py - half) / half, (px - half) / half, -0.85)) + return positions, np.array(pixel_coords, dtype=np.float32) + + +def draw_trajectory_on_satellite( + sat_image_pil: Image.Image, + pixel_coords: np.ndarray, + active_index: Optional[int] = None, +) -> np.ndarray: + """Draw trajectory on satellite image with glow effect (matching demo_inference style).""" + sat_frame = np.array(sat_image_pil.convert("RGB")) + + if len(pixel_coords) >= 2: + # White outline pass (thicker, drawn first) + for idx in range(len(pixel_coords) - 1): + pt1 = tuple(np.round(pixel_coords[idx]).astype(int)) + pt2 = tuple(np.round(pixel_coords[idx + 1]).astype(int)) + cv2.line(sat_frame, pt1, pt2, (255, 255, 255), 3, cv2.LINE_AA) + # Colored line pass (thinner, on top) + for idx in range(len(pixel_coords) - 1): + pt1 = tuple(np.round(pixel_coords[idx]).astype(int)) + pt2 = tuple(np.round(pixel_coords[idx + 1]).astype(int)) + cv2.line(sat_frame, pt1, pt2, (255, 80, 80), 2, cv2.LINE_AA) + + if active_index is not None and len(pixel_coords) > 0: + coord = pixel_coords[min(active_index, len(pixel_coords) - 1)] + px, py = int(round(coord[0])), int(round(coord[1])) + # Outer glow via alpha blending + overlay = sat_frame.copy() + cv2.circle(overlay, (px, py), 12, (0, 255, 100), -1, cv2.LINE_AA) + sat_frame = cv2.addWeighted(sat_frame, 0.7, overlay, 0.3, 0) + # Solid inner circle + white ring + cv2.circle(sat_frame, (px, py), 6, (0, 255, 100), -1, cv2.LINE_AA) + cv2.circle(sat_frame, (px, py), 7, (255, 255, 255), 2, cv2.LINE_AA) + + return sat_frame + + +def build_trajectory_preview(sat_image_pil: Image.Image, pixel_coords: np.ndarray) -> Image.Image: + sat_frame = draw_trajectory_on_satellite(sat_image_pil, 
pixel_coords)
+    preview = cv2.resize(
+        sat_frame,
+        (TRAJECTORY_PREVIEW_SIZE, TRAJECTORY_PREVIEW_SIZE),
+        interpolation=cv2.INTER_LINEAR,
+    )
+    return Image.fromarray(preview)
+
+
+def resolve_demo_sky_pairs(demo_dir: Path) -> Tuple[List[Tuple[Path, Path]], Optional[Path]]:
+    pano_dir = demo_dir / "panorama"
+    mask_dir = demo_dir / "pano_sky_mask"
+    if not pano_dir.exists() or not mask_dir.exists():
+        return [], None
+
+    mask_lookup = {mask_path.stem: mask_path for mask_path in sorted(mask_dir.glob("*.png"))}
+    sky_pairs: List[Tuple[Path, Path]] = []
+    for pano_path in sorted(pano_dir.glob("*")):
+        if pano_path.suffix.lower() not in {".jpg", ".jpeg", ".png"}:
+            continue
+        mask_path = mask_lookup.get(pano_path.stem)
+        if mask_path is not None:
+            sky_pairs.append((pano_path, mask_path))
+
+    if not sky_pairs:
+        return [], None
+
+    default_idx = 0
+    for idx, (pano_path, _) in enumerate(sky_pairs):
+        pano_name_lower = pano_path.name.lower()
+        pano_stem_lower = pano_path.stem.lower()
+        if pano_name_lower in DEFAULT_SKY_FILENAMES or "default" in pano_stem_lower:
+            default_idx = idx
+            break
+
+    ordered_pairs = [sky_pairs[default_idx], *sky_pairs[:default_idx], *sky_pairs[default_idx + 1 :]]
+    return ordered_pairs, ordered_pairs[0][0]
+
+
+# ---------------------------------------------------------------------------
+# Step 1: Satellite Image → 3D Mesh
+# ---------------------------------------------------------------------------
+
+def generate_mesh(sat_image_pil: Image.Image, mesh_resolution: int = 256, progress=gr.Progress()):
+    """Generate a 3D mesh from a satellite image."""
+    if sat_image_pil is None:
+        raise gr.Error("Please upload a satellite image first.")
+
+    print("[generate_mesh] >>> Start")
+    load_model()
+    print("[generate_mesh] Model loaded")
+
+    progress(0.1, desc="Preprocessing satellite image...")
+    print("[generate_mesh] Preprocessing satellite image...")
+    sat_input = SAT_TRANSFORM(sat_image_pil.convert("RGB")).unsqueeze(0).to(DEVICE)
+
+    progress(0.3, desc="Generating triplane features...")
+    print("[generate_mesh] Generating triplane features...")
+    with torch.no_grad():
+        triplane = MODEL.from_sat_to_triplane(sat_input)
+    print("[generate_mesh] Triplane generated successfully")
+
+    progress(0.5, desc="Extracting 3D mesh (this may take a moment)...")
+    print(f"[generate_mesh] Extracting 3D mesh (resolution={mesh_resolution})...")
+    with torch.no_grad():
+        vertices, faces, vertex_colors = MODEL.extract_mesh(triplane, mesh_resolution=mesh_resolution)
+    print(f"[generate_mesh] Mesh extracted: {vertices.shape[0]} vertices, {faces.shape[0]} faces")
+
+    vertices = vertices[:, [1, 2, 0]]
+
+    # Save mesh
+    mesh_path = str(RESULTS_DIR / "mesh.obj")
+    save_obj(vertices, faces, vertex_colors, mesh_path)
+    print(f"[generate_mesh] OBJ saved to {mesh_path}")
+
+    # Also save triplane to state for Step 2
+    state = {"triplane": triplane, "sat_image": sat_image_pil}
+
+    progress(0.9, desc="Preparing 3D visualization...")
+    print("[generate_mesh] Converting OBJ → GLB for 3D preview...")
+
+    # Create a glb file for Gradio's Model3D component.
+    # Use a tempfile so Gradio can reliably serve it via its file cache.
+    glb_path_local = str(RESULTS_DIR / "mesh.glb")
+    mesh_trimesh = trimesh.load(mesh_path, process=False)
+    # Ensure we have a single Trimesh (not a Scene) with vertex normals,
+    # otherwise Chrome's WebGL renderer shows a blank canvas.
+    if isinstance(mesh_trimesh, trimesh.Scene):
+        geometries = list(mesh_trimesh.geometry.values())
+        if geometries:
+            mesh_trimesh = trimesh.util.concatenate(geometries)
+        else:
+            raise gr.Error("Failed to load mesh geometry.")
+    # Accessing the property makes trimesh lazily compute and cache the normals.
+    _ = mesh_trimesh.vertex_normals
+    print(f"[generate_mesh] Mesh has {len(mesh_trimesh.vertices)} verts, {len(mesh_trimesh.faces)} faces, normals: {mesh_trimesh.vertex_normals.shape}")
+    mesh_trimesh.export(glb_path_local, file_type="glb")
+    print(f"[generate_mesh] GLB saved to {glb_path_local} ({os.path.getsize(glb_path_local)} bytes)")
+
+    tmp_glb = tempfile.NamedTemporaryFile(suffix=".glb", delete=False)
+    shutil.copy2(glb_path_local, tmp_glb.name)
+    tmp_glb.close()
+    print(f"[generate_mesh] GLB copied to temp file: {tmp_glb.name}")
+
+    progress(1.0, desc="Done!")
+    print("[generate_mesh] <<< 3D mesh generated successfully!")
+    return tmp_glb.name, mesh_path, state
+
+
+def download_mesh(mesh_path: str):
+    """Return the mesh file for download."""
+    if mesh_path and os.path.exists(mesh_path):
+        return mesh_path
+    return None
+
+
+# ---------------------------------------------------------------------------
+# Step 2: Trajectory → Panorama + Perspective Video
+# ---------------------------------------------------------------------------
+
+def render_trajectory_video(
+    sat_image_pil: Image.Image,
+    trajectory_csv_path: str,
+    sky_path: str,
+    progress=gr.Progress(),
+):
+    """Render panorama and perspective views along a pre-generated trajectory.
+
+    Layout per frame:
+        Top row: satellite image (with camera marker) | panorama RGB
+        Bottom row: 4 perspective views in a horizontal row (left, front, right, back)
+    """
+    print("[render_trajectory_video] >>> Start")
+    load_model()
+
+    sat_size = sat_image_pil.size[0]
+    positions, pixel_coords = read_trajectory_from_csv(trajectory_csv_path, sat_size)
+    if len(positions) == 0:
+        raise gr.Error(f"Trajectory file is empty: {trajectory_csv_path}")
+    print(f"[render_trajectory_video] Loaded {len(positions)} positions from {trajectory_csv_path}")
+
+    progress(0.1, desc="Extracting triplane features...")
+    sat_tensor = SAT_TRANSFORM(sat_image_pil.convert("RGB")).unsqueeze(0).to(DEVICE)
+    with torch.no_grad():
+        triplane = MODEL.from_sat_to_triplane(sat_tensor)
+
+    progress(0.2, desc="Preparing sky condition...")
+    sky_hist = make_histo(sky_path)
+    with torch.no_grad():
+        w_sky = MODEL.w_sky_prepare(sky_hist)
+        sky_feature_2d = MODEL.w_sky2sky_feature_2D(w_sky, sky_hist)
+
+    progress(0.25, desc="Rendering views along trajectory...")
+    intrinsics = build_intrinsics()
+    yaw_values = [0, -90, 90, 180]
+
+    video_dir = RESULTS_DIR / "video_frames"
+    if video_dir.exists():
+        shutil.rmtree(video_dir)
+    video_dir.mkdir(parents=True, exist_ok=True)
+
+    total_positions = len(positions)
+    for idx, position in enumerate(positions):
+        progress(0.25 + 0.6 * idx / total_positions, desc=f"Rendering frame {idx + 1}/{total_positions}...")
+        if idx % 10 == 0 or idx == total_positions - 1:
+            print(f"[render_trajectory_video] Rendering frame {idx + 1}/{total_positions}...")
+
+        c2w = position_to_c2w(position)
+        c2w[:, :3, 3] = c2w[:, :3, 3] * MODEL.position_scale_factor
+
+        with torch.no_grad():
+            pano_result = MODEL.from_3D_to_results(
+                triplane,
+                c2w=c2w,
+                w_sky=w_sky,
+                sky_feature_2D=sky_feature_2d,
+                syn_pano=True,
+            )
+            pano_rgb = tensor_to_numpy_rgb(get_pano_rgb(pano_result))
+
+            per_views = []
+            for yaw in yaw_values:
+                c2w_per = c2w.clone()
+                c2w_per[:, :3, :3] = torch.from_numpy(compose_rotmat(0, 0, yaw)).unsqueeze(0).to(DEVICE)
+                per_result = MODEL.from_3D_to_results(
+                    triplane,
+                    c2w=c2w_per,
+                    w_sky=w_sky,
+                    intrinsics=intrinsics,
+                    sky_feature_2D=sky_feature_2d,
+                    syn_pano=False,
+                    syn_per=True,
+                )
+                per_rgb = tensor_to_numpy_rgb(get_per_rgb(per_result))
+                per_views.append(per_rgb)
+
+        # --- Satellite image with camera position marker ---
+        sat_frame = draw_trajectory_on_satellite(sat_image_pil, pixel_coords, active_index=idx)
+
+        # --- Compose frame ---
+        # Top row: satellite (square) | panorama RGB
+        pano_h, pano_w = pano_rgb.shape[:2]
+        sat_resized = cv2.resize(sat_frame, (pano_h, pano_h))
+        top_row = np.concatenate([sat_resized, pano_rgb], axis=1)
+
+        # Bottom row: 4 perspective views in a horizontal row (left, front, right, back)
+        # Flip back view for consistency
+        per_back = cv2.flip(per_views[3], 1)
+        per_row = np.concatenate([per_views[1], per_views[0], per_views[2], per_back], axis=1)
+
+        # Resize bottom row to match top row width
+        top_width = top_row.shape[1]
+        per_row_h = int(per_row.shape[0] * top_width / per_row.shape[1])
+        per_row_resized = cv2.resize(per_row, (top_width, per_row_h))
+
+        composed = np.concatenate([top_row, per_row_resized], axis=0)
+
+        frame_path = video_dir / f"{idx:04d}.png"
+        cv2.imwrite(str(frame_path), cv2.cvtColor(composed, cv2.COLOR_RGB2BGR))
+
+    progress(0.9, desc="Encoding video...")
+    print("[render_trajectory_video] All frames rendered, encoding video with ffmpeg...")
+    video_path = str(RESULTS_DIR / "trajectory_video.mp4")
+    ffmpeg_path = shutil.which("ffmpeg")
+    if ffmpeg_path is None:
+        raise gr.Error("ffmpeg not found. Please install ffmpeg to generate videos.")
+
+    subprocess.run([
+        ffmpeg_path, "-y", "-framerate", "5",
+        "-i", str(video_dir / "%04d.png"),
+        "-vf", "scale=trunc(iw/2)*2:trunc(ih/2)*2",
+        "-c:v", "libx264", "-pix_fmt", "yuv420p",
+        video_path,
+    ], check=True, capture_output=True)
+
+    print(f"[render_trajectory_video] Video saved to {video_path}")
+    progress(1.0, desc="Done!")
+    return video_path
+
+
+# ---------------------------------------------------------------------------
+# Gradio UI
+# ---------------------------------------------------------------------------
+
+def build_demo():
+
+    # Find sample images from demo directory
+    demo_dir = Path(__file__).resolve().parent / "demo_images"
+    sample_sat_images = sorted((demo_dir / "satellite").glob("*.png")) if (demo_dir / "satellite").exists() else []
+    sample_sat_images_with_csv = [p for p in sample_sat_images if p.with_suffix(".csv").exists()]
+    sample_sky_pairs, default_sky_path = resolve_demo_sky_pairs(demo_dir)
+
+    with gr.Blocks(title="Sat3DGen Demo", theme=gr.themes.Soft()) as demo:
+        gr.Markdown(
+            """
+            ## Sat3DGen: 3D Scene Generation from Satellite Imagery
+
+            Upload a satellite image to **generate a 3D mesh** or **render a walkthrough video**.
+ """ + ) + + # Shared state + inference_state = gr.State(value=None) + mesh_file_path = gr.State(value=None) + + # ---- 3D Mesh Generation ---- + with gr.Tab("3D Mesh Generation"): + with gr.Row(): + with gr.Column(scale=1): + sat_input = gr.Image( + label="Upload Satellite Image", + type="pil", + height=400, + ) + mesh_resolution_slider = gr.Slider( + minimum=128, maximum=512, value=128, step=64, + label="Mesh Resolution (voxel size)", + ) + generate_button = gr.Button("🚀 Generate 3D Mesh", variant="primary", size="lg") + + with gr.Column(scale=2): + mesh_viewer = gr.Model3D(label="3D Mesh Preview", height=500) + download_button = gr.DownloadButton("💾 Download Mesh (.obj)", variant="secondary") + + if sample_sat_images: + gr.Markdown("### Sample Images") + gr.Examples( + examples=[[str(p)] for p in sample_sat_images], + inputs=[sat_input], + label="Click to load a sample satellite image", + examples_per_page=len(sample_sat_images), + ) + + gr.Markdown( + "⚠️ **Note:** The 3D mesh preview may show slight color distortion. " + "The cause is currently under investigation." + ) + + generate_button.click( + fn=generate_mesh, + inputs=[sat_input, mesh_resolution_slider], + outputs=[mesh_viewer, mesh_file_path, inference_state], + ) + mesh_file_path.change( + fn=download_mesh, + inputs=[mesh_file_path], + outputs=[download_button], + ) + # ---- Video Rendering ---- + with gr.Tab("Video Rendering"): + # Hidden state to track the resolved trajectory .csv path + trajectory_csv_state = gr.State(value=None) + sky_path_state = gr.State(value=str(default_sky_path) if default_sky_path is not None else None) + + def load_sat_from_gallery(evt: gr.SelectData): + """Load selected satellite image and check for a same-name trajectory .csv.""" + if evt.index is None or evt.index >= len(sample_sat_images_with_csv): + return None, None, "No image selected.", None + sat_path = sample_sat_images_with_csv[evt.index] + sat_pil = Image.open(str(sat_path)) + csv_path = sat_path.with_suffix(".csv") + if csv_path.exists(): + status_msg = f"✅ Trajectory found: `{csv_path.name}`" + _, pixel_coords = read_trajectory_from_csv(str(csv_path), sat_pil.size[0]) + preview = build_trajectory_preview(sat_pil, pixel_coords) + return sat_pil, str(csv_path), status_msg, preview + status_msg = ( + f"⚠️ No trajectory file found. " + f"Please pre-generate a trajectory and save it as " + f"`{csv_path.name}` in `{sat_path.parent}/` using:\n\n" + f"```\npython inference/make_trajectory.py " + f"--input_img_path {sat_path} --save_same_name\n```" + ) + return sat_pil, None, status_msg, None + + def on_sat_upload(sat_image_pil): + """When user uploads a custom satellite image, no same-name trajectory CSV is available.""" + if sat_image_pil is None: + return None, "No image uploaded.", None + return None, ( + "⚠️ For uploaded images, you need a **trajectory .csv** file with the same name " + "as your satellite image (e.g. `my_image.csv` for `my_image.png`).\n\n" + "You can generate one interactively using either:\n\n" + "- **Jupyter Notebook** (recommended): `inference/make_trajectory.ipynb`\n" + "- **Command line**: " + "`python inference/make_trajectory.py --input_img_path --save_same_name`\n\n" + "If you used the command line **without** `--save_same_name`, " + "the CSV is saved under `results//pixels.csv`. " + "You will need to **copy** it next to your satellite image with the same base name " + "(e.g. copy to `demo_images/satellite/my_image.csv`)." 
+                ), None
+
+            def load_sky_from_gallery(evt: gr.SelectData):
+                """Select one demo panorama street image. The first entry is the default."""
+                if not sample_sky_pairs:
+                    return None, None, "No demo panorama street image is available."
+                if evt.index is None or evt.index >= len(sample_sky_pairs):
+                    sky_path = default_sky_path
+                else:
+                    sky_path = sample_sky_pairs[evt.index][0]
+                default_suffix = " (Default)" if sky_path == default_sky_path else ""
+                status_msg = (
+                    f"Selected demo panorama: `{sky_path.name}`{default_suffix}\n\n"
+                    f"If you do not choose another one, this image will be used."
+                )
+                return str(sky_path), str(sky_path), status_msg
+
+            def render_video_from_state(sat_image, csv_path, sky_path, progress=gr.Progress()):
+                """Render the video using the pre-generated trajectory CSV."""
+                if sat_image is None:
+                    raise gr.Error("Please select or upload a satellite image first.")
+                if csv_path is None or not Path(csv_path).exists():
+                    raise gr.Error(
+                        "No trajectory CSV found. Please pre-generate a trajectory using: "
+                        "python inference/make_trajectory.py --input_img_path <image_path> --save_same_name"
+                    )
+                resolved_sky_path = sky_path or (str(default_sky_path) if default_sky_path is not None else None)
+                if resolved_sky_path is None or not Path(resolved_sky_path).exists():
+                    raise gr.Error("No valid demo panorama is available. Please add one under demo_images/panorama.")
+                return render_trajectory_video(sat_image, csv_path, resolved_sky_path, progress)
+
+            # ===== Main layout =====
+            with gr.Row(equal_height=False):
+                # Left column: satellite image selection
+                with gr.Column(scale=1):
+                    sat_input_video = gr.Image(
+                        label="Upload Satellite Image",
+                        type="pil",
+                        height=300,
+                    )
+                    trajectory_status = gr.Markdown(value="Select a demo image or upload your own.")
+                    selected_sky_preview = gr.Image(
+                        label="Selected Demo Panorama",
+                        value=str(default_sky_path) if default_sky_path is not None else None,
+                        height=180,
+                    )
+                    default_sky_message = "No demo panorama street image is available."
+                    if default_sky_path is not None:
+                        default_sky_message = (
+                            f"Default demo panorama: `{default_sky_path.name}`\n\n"
+                            "If you do not select another demo panorama, this one will be used."
+                        )
+                    sky_status = gr.Markdown(value=default_sky_message)
+                    render_button = gr.Button("🎬 Render Video", variant="primary", size="lg")
+
+                # Middle column: trajectory preview
+                with gr.Column(scale=1):
+                    trajectory_preview = gr.Image(label="Trajectory Preview", height=300)
+
+                # Right column: video output
+                with gr.Column(scale=2):
+                    video_output = gr.Video(label="Rendered Video", height=500)
+
+            # ===== Sample Satellite Images Gallery (only those with a trajectory CSV) =====
+            if sample_sat_images_with_csv:
+                gr.Markdown("### 🛰️ Sample Satellite Images — click to load")
+                sat_gallery = gr.Gallery(
+                    value=[str(p) for p in sample_sat_images_with_csv],
+                    label="Sample Satellite Images (with trajectory)",
+                    columns=5,
+                    rows=1,
+                    height=120,
+                    object_fit="cover",
+                    allow_preview=False,
+                )
+                sat_gallery.select(
+                    fn=load_sat_from_gallery,
+                    inputs=None,
+                    outputs=[sat_input_video, trajectory_csv_state, trajectory_status, trajectory_preview],
+                )
+
+            if sample_sky_pairs:
+                gr.Markdown(
+                    "### 🌤️ Demo Panorama Street Images — the first one is the default\n\n"
+                    "The panorama image and its corresponding sky mask are used to extract a "
+                    "**sky region color histogram**, which serves as a **lighting condition hint** "
+                    "during street-view rendering. 
This only affects the appearance (illumination/color tone) " + "of the rendered views — it does **not** alter the underlying 3D NeRF geometry." + ) + sky_gallery = gr.Gallery( + value=[ + ( + str(pano_path), + f"{pano_path.name} (Default)" if pano_path == default_sky_path else pano_path.name, + ) + for pano_path, _ in sample_sky_pairs + ], + label="Demo Panorama Street Images", + columns=5, + rows=1, + height=120, + object_fit="cover", + allow_preview=False, + ) + sky_gallery.select( + fn=load_sky_from_gallery, + inputs=None, + outputs=[sky_path_state, selected_sky_preview, sky_status], + ) + + # When user uploads a custom image + sat_input_video.upload( + fn=on_sat_upload, + inputs=[sat_input_video], + outputs=[trajectory_csv_state, trajectory_status, trajectory_preview], + ) + + render_button.click( + fn=render_video_from_state, + inputs=[sat_input_video, trajectory_csv_state, sky_path_state], + outputs=[video_output], + ) + return demo + + +if __name__ == "__main__": + demo = build_demo() + port = int(os.environ.get("GRADIO_SERVER_PORT", 7860)) + demo.launch( + server_name="0.0.0.0", + server_port=port, + share=False, + allowed_paths=[ + str(Path(__file__).resolve().parent / "demo_images"), + str(Path(__file__).resolve().parent / "results"), + ], + ) diff --git a/demo_images/pano_sky_mask/_2h2HQiO_LScptNqIlv26g,41.876157,-87.656242,.png b/demo_images/pano_sky_mask/_2h2HQiO_LScptNqIlv26g,41.876157,-87.656242,.png new file mode 100644 index 0000000000000000000000000000000000000000..70daf4bc4b4b108034fa2a28e3700d89f6619654 Binary files /dev/null and b/demo_images/pano_sky_mask/_2h2HQiO_LScptNqIlv26g,41.876157,-87.656242,.png differ diff --git a/demo_images/pano_sky_mask/_MubtGZKtddAL0UcHHf5Cg,41.861737,-87.644142,.png b/demo_images/pano_sky_mask/_MubtGZKtddAL0UcHHf5Cg,41.861737,-87.644142,.png new file mode 100644 index 0000000000000000000000000000000000000000..a70d94244fb86ea9ae8164e21d4129c85351412f Binary files /dev/null and b/demo_images/pano_sky_mask/_MubtGZKtddAL0UcHHf5Cg,41.861737,-87.644142,.png differ diff --git a/demo_images/pano_sky_mask/_U8e_qINb8isqDbFusZadg,41.867593,-87.673364,.png b/demo_images/pano_sky_mask/_U8e_qINb8isqDbFusZadg,41.867593,-87.673364,.png new file mode 100644 index 0000000000000000000000000000000000000000..43f76fcf5e5aa906f4d20f3d56d277adbc34f17d Binary files /dev/null and b/demo_images/pano_sky_mask/_U8e_qINb8isqDbFusZadg,41.867593,-87.673364,.png differ diff --git a/demo_images/pano_sky_mask/_cF0n90rix962-yUqP13Ow,41.870943,-87.647100,.png b/demo_images/pano_sky_mask/_cF0n90rix962-yUqP13Ow,41.870943,-87.647100,.png new file mode 100644 index 0000000000000000000000000000000000000000..38feda4b06acd2889846c18e8e8c07cba5f01c7b Binary files /dev/null and b/demo_images/pano_sky_mask/_cF0n90rix962-yUqP13Ow,41.870943,-87.647100,.png differ diff --git a/demo_images/pano_sky_mask/default_panorama.png b/demo_images/pano_sky_mask/default_panorama.png new file mode 100644 index 0000000000000000000000000000000000000000..cdda8d9058388635d17d1dcc4e966a1dcbfbe292 Binary files /dev/null and b/demo_images/pano_sky_mask/default_panorama.png differ diff --git a/demo_images/panorama/_2h2HQiO_LScptNqIlv26g,41.876157,-87.656242,.jpg b/demo_images/panorama/_2h2HQiO_LScptNqIlv26g,41.876157,-87.656242,.jpg new file mode 100644 index 0000000000000000000000000000000000000000..bd2946566928cf5cf33aede9bbbe76ae6f1fe960 --- /dev/null +++ b/demo_images/panorama/_2h2HQiO_LScptNqIlv26g,41.876157,-87.656242,.jpg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid 
sha256:d7e42b80e5ea3f60c7048e7f796d5999acea08e888a0de19155fee4007ce599b +size 758305 diff --git a/demo_images/panorama/_MubtGZKtddAL0UcHHf5Cg,41.861737,-87.644142,.jpg b/demo_images/panorama/_MubtGZKtddAL0UcHHf5Cg,41.861737,-87.644142,.jpg new file mode 100644 index 0000000000000000000000000000000000000000..7b5c2dad70b2409c23ae8ba401d96079a78bde1f --- /dev/null +++ b/demo_images/panorama/_MubtGZKtddAL0UcHHf5Cg,41.861737,-87.644142,.jpg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:99f537d09b5e0373947499d78c3c75553df12f3a995c579d42619b2ecec2c8c6 +size 546317 diff --git a/demo_images/panorama/_U8e_qINb8isqDbFusZadg,41.867593,-87.673364,.jpg b/demo_images/panorama/_U8e_qINb8isqDbFusZadg,41.867593,-87.673364,.jpg new file mode 100644 index 0000000000000000000000000000000000000000..1aebec85f10dfbd63ddbc01df11cd8bf725f4be9 --- /dev/null +++ b/demo_images/panorama/_U8e_qINb8isqDbFusZadg,41.867593,-87.673364,.jpg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:339379e8d253fbc8ef2b5f255ccdd269b3b9951b0e91bb4b1ab9769c3303a05a +size 681308 diff --git a/demo_images/panorama/_cF0n90rix962-yUqP13Ow,41.870943,-87.647100,.jpg b/demo_images/panorama/_cF0n90rix962-yUqP13Ow,41.870943,-87.647100,.jpg new file mode 100644 index 0000000000000000000000000000000000000000..423ffb70146e1daeee42c3a972b29f246c690206 --- /dev/null +++ b/demo_images/panorama/_cF0n90rix962-yUqP13Ow,41.870943,-87.647100,.jpg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:74470161cc795e013159110e5c6f899102441b6898d35295f7f332f67a080b1e +size 588386 diff --git a/demo_images/panorama/default_panorama.jpg b/demo_images/panorama/default_panorama.jpg new file mode 100644 index 0000000000000000000000000000000000000000..fbeece22e2cd29d1900652babb27cfc495d37e8b --- /dev/null +++ b/demo_images/panorama/default_panorama.jpg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:96a263053d453cc15dcfbb7f10acf622a442cf00feea89f15cfd6c6e67d7c825 +size 669543 diff --git a/demo_images/satellite/sat_demo_1.csv b/demo_images/satellite/sat_demo_1.csv new file mode 100644 index 0000000000000000000000000000000000000000..b35ec500e7729c855c304a870a35ab640d15111b --- /dev/null +++ b/demo_images/satellite/sat_demo_1.csv @@ -0,0 +1,80 @@ +w,h,angle +144.08924981624142,415.97353283362116,-0.18336545311750235 +147.8063616554169,415.2841997740026,-0.1916271018999594 +151.5187165375632,414.5639745251622,-0.19958468279565406 +155.2264035423511,413.81399216729886,-0.2072383172873915 +158.92951174945165,413.03538778061204,-0.21458828129442548 +162.62813023853556,412.2292964453005,-0.2216349837996983 +166.32234808927382,411.3968532415638,-0.22837894637381767 +170.0122543813373,410.53919324960094,-0.23482078362743702 +173.69793819439678,409.6574515496111,-0.2409611846164222 +177.37948860812324,408.75276322179354,-0.24680089521144738 +181.05699470218752,407.82626334634733,-0.25234070143497356 +184.73054555626047,406.87908700347174,-0.2575814137620896 +188.40023025001304,405.912369273366,-0.26252385237404635 +192.06613786311598,404.92724523622894,-0.2671688333485118 +195.72835747524027,403.92484997225995,-0.2715171557677608 +199.38697816605676,402.9063185616585,-0.2755695897210954 +203.04208901523637,401.87278608462327,-0.2793268787804945 +206.69377909429483,400.82538756992574,-0.2828246694975033 +210.3421165035366,399.76512579572324,-0.2861740942171928 +213.98710300392804,398.69258518484435,-0.2893973330619872 +217.62872727016295,397.60826763423904,-0.2924947208957811 
+221.26697797693444,396.5126750408567,-0.2954665770295263 +224.90184379893603,395.406309301647,-0.29831320384182997 +228.53331341086113,394.2896723135595,-0.30103488546093005 +232.16137548740303,393.16326597354373,-0.3036318865062072 +235.78601870325522,392.0275921785492,-0.3061044508865356 +239.40723173311102,390.88315282552554,-0.3084528006533492 +243.02500325166392,389.7304498114224,-0.310677064734146 +246.63932225659002,388.56998520925583,-0.3127616106049064 +250.250250337565,387.40230066389296,-0.3146706918791157 +253.85801311804832,386.2280272394269,-0.3163992058089867 +257.462858740972,385.0478082759211,-0.317946959754716 +261.0650353492678,383.8622871134386,-0.31931377695209934 +264.6647910858677,382.6721070920426,-0.3204994950300902 +268.26237409370333,381.4779115517963,-0.32150396469795806 +271.8580325157067,380.280343832763,-0.32232704860517963 +275.45201449480976,379.08004727500565,-0.3229686203742507 +279.0445681739442,377.8776652185876,-0.32342856380907803 +282.6359416960419,376.67384100357197,-0.3237067722792018 +286.22638320403473,375.4692179700221,-0.3238031482815028 +289.8161408408547,374.26443945800094,-0.3237176031796829 +293.40546274943335,373.06014880757175,-0.32345005712294383 +296.9945970727028,371.8569893587978,-0.32300043914420024 +300.5837919535948,370.65560445174236,-0.3223686874381936 +304.17329553504135,369.4566374264683,-0.3215547498199447 +307.7633559599741,368.2607316230391,-0.3205585843643651 +311.354221371325,367.0685303815179,-0.319380160226465 +314.94613991202596,365.8806770419677,-0.31801945864198616 +318.53935972500875,364.6978149444519,-0.31647639799992155 +322.1341289691096,363.5205877267635,-0.31472429375646727 +325.7307013782104,362.34974339236385,-0.31269725268400456 +329.3293444354928,361.1862873335892,-0.310385152116998 +332.930327706209,360.0312639194457,-0.3077878199926932 +336.5339207556115,358.8857175189398,-0.3049050767172916 +340.14039314895297,357.7506925010781,-0.30173674148711155 +343.7500144514854,356.62723323486694,-0.29828263932173654 +347.36305422846135,355.51638408931245,-0.2945426088126659 +350.9797820451333,354.4191894334213,-0.2905165105901754 +354.6004674667536,353.33669363619987,-0.28620423650766497 +358.2253800585747,352.26994106665444,-0.28160571954249913 +361.85478938584885,351.21997609379133,-0.27672094440825273 +365.4889650138288,350.18784308661725,-0.2715499588706621 +369.1281765077666,349.1745864141384,-0.2660928857535203 +372.7726934329148,348.181250445361,-0.2603499356189212 +376.4227853545259,347.20887954929174,-0.25432142009847475 +380.0787218378521,346.2585180949368,-0.24800776584554812 +383.740772448146,345.33121045130275,-0.241409529074176 +387.40920675066,344.428000987396,-0.2345411440931796 +391.08427457796313,343.5498854341411,-0.22748850551208588 +394.76610334860516,342.697557790409,-0.22028475763312177 +398.4547736038706,341.8715965096973,-0.21293052703920015 +402.15036578986684,341.07257981090635,-0.20542645848667948 +405.8529603527015,340.30108591293657,-0.1977732857614259 +409.5626377384817,339.5576930346886,-0.1899718362626802 +413.2794783933151,338.8429793950629,-0.18202303553559968 +417.0035627633091,338.1575232129599,-0.17392791173094835 +420.7349712945709,337.50190270728,-0.16568759996858143 +424.47378443320827,336.876696096924,-0.15730334658129785 +428.2200826253283,336.2824816007919,-0.14877651321174323 diff --git a/demo_images/satellite/sat_demo_1.png b/demo_images/satellite/sat_demo_1.png new file mode 100644 index 
0000000000000000000000000000000000000000..981ebeb127959c032e19510b5dbac0e85d3d6fc5 --- /dev/null +++ b/demo_images/satellite/sat_demo_1.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:aa003952162f5421887d146501c0c289b663be299e778223f220576f9f034c2e +size 597902 diff --git a/demo_images/satellite/sat_demo_10.csv b/demo_images/satellite/sat_demo_10.csv new file mode 100644 index 0000000000000000000000000000000000000000..f45b6bd2259eb526638ebad426d70a7294162300 --- /dev/null +++ b/demo_images/satellite/sat_demo_10.csv @@ -0,0 +1,80 @@ +w,h,angle +251.77379078980815,506.48506699922376,-1.5962965135658407 +251.6946639863019,503.3827504363134,-1.5899114787791961 +251.6353784262066,500.2816325794658,-1.583711933304399 +251.59533840981896,497.18167827716996,-1.577698887107828 +251.573948237435,494.0828523779136,-1.5718732521378143 +251.5706122093514,490.9851197301854,-1.566235848000831 +251.58473462586434,487.8884451824738,-1.5607874075983847 +251.61571978727022,484.7927935832667,-1.5555285827043115 +251.66297199386548,481.69812978105307,-1.550459949460317 +251.72589554594646,478.6044186243208,-1.5455820137710325 +251.80389474380934,475.51162496155814,-1.5408952165813692 +251.8963738877509,472.419713641254,-1.5363999390220173 +252.00273727806714,469.3286495118962,-1.5320965074068411 +252.1223892150546,466.2383974219733,-1.5279851980735266 +252.25473399900963,463.1489222199737,-1.5240662420541675 +252.39917593022867,460.0601887543857,-1.5203398295682435 +252.55511930900803,456.97216187369764,-1.5168061143295195 +252.721968435644,453.8848064263978,-1.5134652176602763 +252.89912761043306,450.7980872609747,-1.510317232407784 +253.08600113367163,447.71196922591656,-1.5073622266576125 +253.28199330565585,444.62641716971166,-1.5046002472402225 +253.48650842668238,441.5413959408485,-1.5020313230289242 +253.69895079704742,438.4568703878154,-1.4996554680248488 +253.9187247170473,435.37280535910054,-1.4974726842290331 +254.1452344869785,432.28916570319234,-1.4954829642995082 +254.37788440713751,429.2059162685796,-1.4936862939927897 +254.6160787778204,426.12302190375,-1.492082654388442 +254.85922189932373,423.0404474571921,-1.4906720238983082 +255.10671807194393,419.95815777739466,-1.4894543800584925 +255.3579715959773,416.8761177128455,-1.4884297011049905 +255.61238677172005,413.794292112033,-1.487597967333263 +255.86936789946876,410.71264582344594,-1.4869591622421339 +256.1283192795198,407.6311436955723,-1.486513273461653 +256.3886452121694,404.54975057690046,-1.486260293465556 +256.649749997714,401.4684313159188,-1.4862002200688649 +256.9110379364502,398.387150761116,-1.486333056710555 +257.17191332867395,395.3058737609798,-1.4866588125202198 +257.43178047468194,392.22456516399893,-1.4871775021714635 +257.69004367477044,389.14318981866177,-1.487889145518694 +257.9461072292358,386.0617125734566,-1.4887937670194946 +258.1993754383744,382.9800982768717,-1.4898913949411847 +258.44925260248266,379.8983117773954,-1.4911820603518022 +258.69514302185684,376.8163179235161,-1.4926657958943985 +258.93645099679344,373.7340815637222,-1.494342634345531 +259.1725808275888,370.65156754650195,-1.4962126069562023 +259.40293681453926,367.5687407203439,-1.4982757415761983 +259.6269232579412,364.48556593373615,-1.500532060560794 +259.84394445809096,361.4020080351672,-1.5029815784604483 +260.05340471528507,358.3180318731254,-1.5056242994936517 +260.2547083298197,355.2336022960989,-1.5084602148028183 +260.44725960199133,352.1486841525763,-1.5114892994952824 
+260.6304628320963,349.06324229104587,-1.5147115094698156 +260.803722320431,345.97724155999583,-1.5181267780307877 +260.96644236729185,342.89064680791483,-1.5217350122929325 +261.11802727297504,339.803422883291,-1.5255360893788739 +261.2578813377772,336.7155346346126,-1.529529852414944 +261.3854088619945,333.6269469103681,-1.5337161063284677 +261.5000141459234,330.53762455904587,-1.5380946134538527 +261.60110148986024,327.4475324291343,-1.5426650889538245 +261.68807519410143,324.3566353691216,-1.5474271960641446 +261.7603395589433,321.2648982274962,-1.5523805411715577 +261.8172988846823,318.1722858527465,-1.5575246687360145 +261.85835747161474,315.07876309336075,-1.5628590560695212 +261.88291962003694,311.9842947978273,-1.5683831079861963 +261.8903896302454,308.8888458146346,-1.5740961513394656 +261.8801718025364,305.792380992271,-1.579997429463762 +261.8516704372063,302.6948651792246,-1.586086096541233 +261.8042898345516,299.5962632239841,-1.5923612119146524 +261.7374342948685,296.49653997503765,-1.5988217343699878 +261.65050811845344,293.3956602808736,-1.605466516415342 +261.54291560560284,290.29358898998044,-1.6122942985829811 +261.4140610566131,287.1902909508463,-1.6193037037851263 +261.26334877178044,284.0857310119597,-1.6264932317546636 +261.09018305140137,280.9798740218088,-1.6338612536053803 +260.8939681957722,277.8726848288822,-1.6414060065462956 +260.67410850518934,274.76412828166815,-1.649125588788356 +260.43000827994916,271.65416922865495,-1.6570179546813917 +260.16107182034796,268.54277251833093,-1.6650809101219561 +259.8667034266822,265.4299029991846,-1.6733121082731401 diff --git a/demo_images/satellite/sat_demo_10.png b/demo_images/satellite/sat_demo_10.png new file mode 100644 index 0000000000000000000000000000000000000000..75e802b829c6db339bb2cc618aad325ed3ad11d1 --- /dev/null +++ b/demo_images/satellite/sat_demo_10.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:e3f645e720cfa988da15dc83b1502cd89fffadd314a4875343edee2de33e2a56 +size 498840 diff --git a/demo_images/satellite/sat_demo_2.csv b/demo_images/satellite/sat_demo_2.csv new file mode 100644 index 0000000000000000000000000000000000000000..8188698efeaf0b3537332a97de508abdc2f92003 --- /dev/null +++ b/demo_images/satellite/sat_demo_2.csv @@ -0,0 +1,80 @@ +w,h,angle +189.29549276619161,359.8700402440079,-0.10010759826059114 +192.48106987521334,359.5500701924698,-0.11839126842127932 +195.65110782457958,359.1730020058246,-0.13604480037322997 +198.8061846033885,358.74110236184436,-0.1530476807607791 +201.9468782007374,358.2566379382992,-0.16938198305156543 +205.0737666057242,357.72187541296057,-0.18503223593183252 +208.18742780744643,357.1390814635991,-0.19998527007277683 +211.2884397950019,356.5105227679858,-0.2142300486006032 +214.3773805574882,355.8384660038919,-0.22775748627435674 +217.45482808400308,355.1251778490881,-0.24056026194474564 +220.52136036364416,354.3729249813454,-0.2526326283666673 +223.57755538550924,353.5839740784348,-0.26397022289708955 +226.62399113869574,352.7605918181271,-0.2745698820651502 +229.66124561230168,351.9050448781936,-0.28442946247271034 +232.6898967954246,351.0195999364051,-0.2935476699852364 +235.71052267716212,350.1065236705323,-0.3019238987255484 +238.723701246612,349.16808275834654,-0.3095580809856762 +241.73001049287186,348.20654387761846,-0.31645054883020274 +244.73002840503952,347.2241737061194,-0.3226019078885854 +247.72433297221258,346.22323892162,-0.32801292360093554 +250.71350218348863,345.20600620189134,-0.33268442001591747 
+253.6981140279655,344.17474222470435,-0.3366171911082836 +256.67874649474084,343.13171366782996,-0.3398119245041953 +259.65597757291226,342.07918720903916,-0.3422691374558302 +262.63038525157754,341.0194295261029,-0.34398912489294425 +265.6025475198343,339.9547072967922,-0.3449719193916429 +268.57304236678016,338.88728719887797,-0.3452172629334773 +271.5424477815131,337.8194359101312,-0.34472459037841324 +274.5113417531305,336.75342010832287,-0.34349302463484427 +277.48030227073014,335.69150647122376,-0.3415213835780211 +280.44990732340966,334.63596167660506,-0.33880819883919727 +283.4207349002668,333.5890524022375,-0.3353517466522149 +286.3933629903993,332.5530453258923,-0.3311500910094189 +289.36836958290473,331.5302071253402,-0.3262011394219666 +292.34633266688076,330.52280447835227,-0.3205027116161005 +295.32783023142525,329.53310406269946,-0.3140526215012159 +298.31344026563573,328.56337255615273,-0.3068487727264839 +301.30374075860993,327.61587663648294,-0.2988892680861183 +304.29930969944553,326.6928829814612,-0.2901725329340693 +307.30072507724014,325.7966582688584,-0.28069745261963525 +310.3085648810917,324.92946917644554,-0.27046352375253535 +313.32340710009765,324.0935823819935,-0.2594710188366035 +316.3458297233557,323.2912645632733,-0.24772116348001227 +319.3764107399635,322.5247823980558,-0.2353277333488414 +322.41564618071,321.79606415310974,-0.22309389121465012 +325.46344388626784,321.1046094212929,-0.2113726960242702 +328.5194547765243,320.44885695341236,-0.20016862704024516 +331.5833286964775,319.82724106198975,-0.18948390072408167 +334.65471549112544,319.23819605954657,-0.17932001922478794 +337.733265005466,318.6801562586045,-0.1696778556090425 +340.8186270844973,318.1515559716849,-0.1605577362324779 +343.91045157321713,317.6508295113092,-0.1519595196667983 +347.0083883166238,317.17641118999927,-0.14388267172814762 +350.112087159715,316.72673532027636,-0.13632633626669216 +353.2211979474887,316.3002362146621,-0.12928940148386994 +356.3353705249432,315.8953481856781,-0.1227705616275724 +359.45425473707616,315.5105055458457,-0.1167683739912103 +362.5775004288858,315.1441426076866,-0.11128131120283047 +365.70475744537003,314.79469368372224,-0.1063078088369347 +368.8356756315268,314.46059308647415,-0.10184630841898709 +371.9699048323541,314.1402751284639,-0.09789529591799641 +375.10709489284994,313.83217412221296,-0.09445333583925615 +378.24689565801225,313.5347243802429,-0.09151910103847 +381.3889569728391,313.24636021507524,-0.08909139838032573 +384.53292868232853,312.96551593923164,-0.08716919036072429 +387.6784606314784,312.6906258652334,-0.08575161280267399 +390.82520266528684,312.42012430560214,-0.08483798872459422 +393.9728046287516,312.1524455728595,-0.08442783846231923 +397.1209163668709,311.88602397952684,-0.084520886108343 +400.2691877246426,311.61929383812594,-0.08511706231343225 +403.41726854706485,311.35068946117804,-0.08621650347147215 +406.56480867913547,311.0786451612048,-0.08781954729122714 +409.71145796585245,310.80159525072776,-0.08992672473337979 +412.8568662522138,310.5179740422684,-0.09253874827356537 +416.0006833832176,310.2262158483483,-0.09565649643178505 +419.1425592038617,309.92475498148895,-0.09928099449148606 +422.2821435591443,309.6120257542119,-0.1034133913181809 +425.4190862940631,309.28646247903873,-0.10805493217524006 +428.55303725361637,308.9464994684909,-0.11320692742940021 diff --git a/demo_images/satellite/sat_demo_2.png b/demo_images/satellite/sat_demo_2.png new file mode 100644 index 
0000000000000000000000000000000000000000..bbbe7659699fad23e0fa63c1fbfb772a5b2f0c90 --- /dev/null +++ b/demo_images/satellite/sat_demo_2.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:2ce6c22d4a8d01bb67c5d0c550a1eb6a806df4a84456ba89de4113fd4073aae2 +size 527379 diff --git a/demo_images/satellite/sat_demo_3.csv b/demo_images/satellite/sat_demo_3.csv new file mode 100644 index 0000000000000000000000000000000000000000..deea8a33cf0932e60ff87be1f4f77fed6f0b94a5 --- /dev/null +++ b/demo_images/satellite/sat_demo_3.csv @@ -0,0 +1,80 @@ +w,h,angle +251.7589723122136,195.16413878960634,1.6219979451631072 +251.56249148498733,198.99817968609221,1.6180137193684725 +251.3812692462885,202.83336692447182,1.6141277273917483 +251.21493261836144,206.66967287080732,1.6103400200699827 +251.0631086234499,210.50706989116046,1.6066506326656056 +250.9254242837982,214.3455303515934,1.6030595859638175 +250.8015066216504,218.18502661816808,1.5995668873432942 +250.6909826592504,222.02553105694636,1.5961725318206614 +250.59347941884255,225.86701603399024,1.592876503067845 +250.50862392267075,229.70945391536173,1.5896787744008345 +250.436043192979,233.5528170671226,1.5865793097401246 +250.37536425201165,237.39707785533514,1.5835780645422959 +250.32621412201254,241.24220864606102,1.580674986701454 +250.28821982522587,245.08818180536238,1.5778700174215694 +250.26100838389567,248.9349696993011,1.5751630920583501 +250.2442068202661,252.7825446939392,1.572554140931219 +250.23744215658115,256.6308791553387,1.5700430901047937 +250.2403414150849,260.4799454495613,1.5676298621402578 +250.25253161802152,264.3297159426692,1.5653143768163955 +250.27363978763506,268.1801630007244,1.563096551820385 +250.30329294616942,272.0312589897886,1.5609763034082547 +250.3411181158689,275.88297627592397,1.5589535470358762 +250.38674231897758,279.73528722519245,1.5570281979595657 +250.43979257773938,283.58816420365594,1.5552001718072055 +250.49989591439845,287.44157957737644,1.55346938512006 +250.56667935119904,291.295505712416,1.5518357558653915 +250.63976991038496,295.1499149748364,1.5502992039195989 +250.71879461420036,299.00477973069974,1.5488596515232405 +250.80338048488954,302.86007234606797,1.5475170237073435 +250.8931545446964,306.71576518700306,1.5462712486911374 +250.98774381586483,310.5718306195667,1.54512225825205 +251.08677532063928,314.4282410098214,1.5440699880682236 +251.1898760812637,318.2849687238287,1.5431143780328778 +251.29667311998202,322.1419861276506,1.542255372541694 +251.4067934590384,325.9992655873491,1.5414929207531667 +251.5198641206771,329.85677946898636,1.540826976822106 +251.63551212714196,333.7145001386241,1.5402575001059209 +251.75336450067715,337.5723999623243,1.5397844553450823 +251.87304826352678,341.43045130614917,1.5394078128167052 +251.99419043793495,345.28862653616034,1.5391275484620561 +252.11641804614564,349.14689801841996,1.53894364398779 +252.23935811040297,353.00523811898995,1.5388560869413104 +252.362637652951,356.8636192039322,1.5388648707599366 +252.48588369603394,360.7220136393089,1.5389699947943631 +252.60872326189573,364.58039379118173,1.539171464305872 +252.7307833727805,368.4387320256129,1.5394692904379375 +252.85169105093232,372.2970007086642,1.5398634901616817 +252.97107331859527,376.15517220639765,1.5403540861954566 +253.08855719801343,380.0132188848752,1.5409411068982657 +253.2037697114309,383.8711131101589,1.5416245861371505 +253.31633788109173,387.72882724831055,1.542404563128172 +253.42588872924006,391.5863336653922,1.5432810822512328 
+253.53204927811987,395.44360472746575,1.5442541928380367 +253.63444654997534,399.3006128005934,1.5453239489337136 +253.73270756705048,403.1573302508369,1.5464904090312441 +253.8264593515894,407.01372944425816,1.5477536357789905 +253.91532892583615,410.86978274691927,1.549113695660841 +253.9989433120349,414.72546252488223,1.5505706586489314 +254.07692953242963,418.5807411442089,1.5521245978284346 +254.14891460926447,422.43559097096124,1.5537755889945877 +254.21452556478346,426.2899843712012,1.5555237102213084 +254.27338942123072,430.1438937109908,1.5573690414014638 +254.32513320085033,433.9972913563921,1.5593116637583642 +254.36938392588632,437.85014967346683,1.5613516593282712 +254.40576861858284,441.7024410282771,1.5634891104138793 +254.43391430118393,445.5541377868849,1.5657240990082366 +254.4534479959337,449.4052123153521,1.568056706189198 +254.46399672507619,453.2556369797407,1.5704870114840745 +254.46518751085551,457.1053841461127,1.5730150922044037 +254.45664737551567,460.95442618052994,1.5756410227505517 +254.4380033413009,464.8027354490546,1.5783648738865514 +254.40888243045518,468.6502843177485,1.5811867119842797 +254.36891166522253,472.49704515267354,1.5841065982378377 +254.31771806784715,476.3429903198917,1.58712458784783 +254.254928660573,480.18809218546505,1.5902407291752831 +254.18017046564432,484.0323231154556,1.5934550628663036 +254.09307050530506,487.8756554759252,1.5967676209465504 +253.99325580179934,491.7180616329357,1.6001784258869012 +253.88035337737128,495.5595139525493,1.6036874896401225 diff --git a/demo_images/satellite/sat_demo_3.png b/demo_images/satellite/sat_demo_3.png new file mode 100644 index 0000000000000000000000000000000000000000..0cb999d3d61729b011b979d93980ee361d2cbaad --- /dev/null +++ b/demo_images/satellite/sat_demo_3.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:97eb6f83034d6e84f66384d26b1be4b745a499dbe02e718ec0f15e2f194d56c8 +size 530207 diff --git a/demo_images/satellite/sat_demo_4.csv b/demo_images/satellite/sat_demo_4.csv new file mode 100644 index 0000000000000000000000000000000000000000..56692ae402a8142dba1580591ffe7e5eadecc902 --- /dev/null +++ b/demo_images/satellite/sat_demo_4.csv @@ -0,0 +1,80 @@ +w,h,angle +273.7159433789104,463.8503265404192,-1.729753543220757 +273.015297794522,459.4797514117093,-1.6999939357458085 +272.44820140061233,455.11482892499333,-1.671599444120782 +272.00720211182926,450.7547993589921,-1.644665789841402 +271.68484784282,446.3989029924256,-1.6192730160878097 +271.4736865082321,442.0463801040147,-1.5954861890252898 +271.36626602271326,437.69647097247986,-1.5733564823670119 +271.35513430091095,433.34841587654154,-1.5529225327988243 +271.4328392574728,429.0014550949204,-1.5342119598983448 +271.5919288070462,424.6548289063369,-1.5172429566105714 +271.8249508642788,420.3077775895114,-1.5020258723058235 +272.12445334381806,415.9595414231646,-1.4885647275092728 +272.4829841603115,411.6093606860168,-1.4768586157406312 +272.8930912284069,407.2564756567889,-1.4669029623587717 +273.34732246275155,402.9001266142011,-1.4586906222499822 +273.838225777993,398.539553836974,-1.4522128074592082 +274.35834908877894,394.17399760382824,-1.4474598425603604 +274.9002403097567,389.80269819348416,-1.4444217500134924 +275.456447355574,385.42489588466236,-1.4430886703700083 +276.01951814087835,381.03983095608334,-1.4434511233914653 +276.58200058031724,376.64674368646763,-1.4455001163755035 +277.1364425885382,372.2448743545358,-1.449227105620829 +277.67539208018883,367.8334632390082,-1.4546238163744258 
+278.19139696991664,363.4117506186056,-1.4616819261087093 +278.6770051723691,358.9789767720483,-1.4703947398786308 +279.1247541134111,354.53439143412777,-1.4814490896953956 +279.52377594959347,350.0803143304569,-1.49658137087014 +279.8549214600764,345.6265311901413,-1.5161787268554454 +280.0978066399815,341.18394095105606,-1.5404378492382544 +280.2320474844308,336.763442551077,-1.5696082942920186 +280.2372599885459,332.3759349280793,-1.6039822340923162 +280.0930601474487,328.03231701993843,-1.6438785026112415 +279.7790639562608,323.74348776452985,-1.6896181869604185 +279.27488741010427,319.52034609972907,-1.7414888210596562 +278.56014650410066,315.37379096341124,-1.799694831239815 +277.6144572333719,311.31472129345224,-1.86429397158894 +276.4174355930398,307.35403602772715,-1.9339491173589447 +274.9534344175045,303.5014684239424,-2.0028420872324078 +273.2288567214531,299.76132545101376,-2.069094973792189 +271.2565167426674,296.1363363516171,-2.1325360100778896 +269.0492294305888,292.629230193298,-2.193094955858947 +266.6198097346583,289.2427360436015,-2.2507872992947413 +263.9810726043173,285.9795829700732,-2.305697099412993 +261.1458329890069,282.8425000402582,-2.3579603851393753 +258.1269058381682,279.8342163217019,-2.407750419416202 +254.9371061012425,276.95746088194954,-2.455265462468547 +251.58924872767093,274.21496278854653,-2.5007191781244833 +248.09614866689472,271.60945110903816,-2.5443335223062995 +244.470620868355,269.1436549109696,-2.5863337975292247 +240.725480281493,266.8203032618863,-2.6269455058043913 +236.87354185574983,264.64212522933343,-2.6663926439252514 +232.92762054056683,262.6118498808564,-2.7048934129145565 +228.90051162529787,260.73217888170706,-2.7422093174371143 +224.80273721450672,259.0026455247898,-2.7775488929155614 +220.64046028136931,257.41670733151483,-2.8108003398483836 +216.41935065406048,255.9671344761878,-2.841966384111274 +212.14507816075522,254.64669713311426,-2.8710666193026864 +207.82331262962862,253.44816547659997,-2.898133313814539 +203.45972388885542,252.36430968095036,-2.923207612849412 +199.05998176661097,251.3878999204713,-2.946336198128481 +194.62975609107,250.51170636946824,-2.967568417229739 +190.1747166904075,249.7284992022469,-2.9869538592529667 +185.70053339279863,249.03104859311293,-3.0045403314215906 +181.21287602641817,248.4121247163718,-3.0203721793059 +176.71741441944127,247.8644977463293,-3.0344888886738506 +172.21981840004287,247.38093785729106,-3.046923907039604 +167.7257577963979,246.95421522356273,-3.057703625810944 +163.24090243668135,246.57710001944977,-3.066846468091847 +158.77092214906827,246.24236241925797,-3.074362031680405 +154.3214867617336,245.94277259729282,-3.0802502409946895 +149.89826610285238,245.6711007278601,-3.084500465212189 +145.5069300005995,245.42011698526545,-3.087090562709449 +141.15314828315007,245.18259154381428,-3.0879858139963208 +136.84259077867898,244.9512945778125,-3.087137707006282 +132.58092731536124,244.71899626156556,-3.0844825402967952 +128.37382772137187,244.47846676937914,-3.079939812188778 +124.22696182488582,244.22247627555893,-3.0734103683439793 +120.1459994540781,243.94379495441046,-3.0647742886096796 +116.1366104371237,243.63519298023937,-3.0538885090378827 diff --git a/demo_images/satellite/sat_demo_4.png b/demo_images/satellite/sat_demo_4.png new file mode 100644 index 0000000000000000000000000000000000000000..26a45fe1c4c8533a00946bb92ee15677e9220936 --- /dev/null +++ b/demo_images/satellite/sat_demo_4.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid 
sha256:99ed1df728aa77fca47a491c892026f19836669375546133d0971475f52a7449 +size 557501 diff --git a/demo_images/satellite/sat_demo_5.csv b/demo_images/satellite/sat_demo_5.csv new file mode 100644 index 0000000000000000000000000000000000000000..0673ec7f4187f671d4d699682c1abc5ca22a5fc8 --- /dev/null +++ b/demo_images/satellite/sat_demo_5.csv @@ -0,0 +1,80 @@ +w,h,angle +183.62571042553512,347.4802717007766,-0.03604428925814849 +186.84204810349965,347.36429086373096,-0.061070953663000485 +190.0150675218896,347.1702702714918,-0.08632959230199137 +193.14505915316505,346.8993860870459,-0.1117908181772409 +196.23231346978557,346.5528144733792,-0.13742372351956728 +199.27712094421125,346.13173159347826,-0.16319608281560502 +202.2797720489019,345.6373136103291,-0.18907458045272674 +205.2405572563173,345.0707366869184,-0.21502505929020307 +208.15976703891747,344.43317698623235,-0.24101278572457824 +211.03769186916207,343.725810671257,-0.2670027262254262 +213.87462221951108,342.949813904979,-0.2929598298947271 +216.67084856242442,342.1063628503845,-0.3188493113759756 +219.42666137036184,341.19663367045985,-0.34463692842950355 +222.14235111578336,340.22180252819163,-0.3702892486881543 +224.81820827114865,339.18304558656575,-0.3957739005042193 +227.45452330891766,338.08153900856865,-0.4210598033800554 +230.0515867015504,336.9184589571869,-0.4461173741891566 +232.6096889215064,335.6949815954064,-0.47091870621960363 +235.12912044124585,334.41228308621385,-0.4954377189595105 +237.6101717332284,333.0715395925954,-0.5196502774361933 +240.05313326991407,331.6739272775373,-0.5435342808036348 +242.45829552376267,330.220622304026,-0.5670697206803392 +244.82594896723398,328.7128008350477,-0.590238710463797 +247.15638407278792,327.1516390335889,-0.6130254874591606 +249.4498913128844,325.53831306263567,-0.6354163901425827 +251.7067611599833,323.8739990851746,-0.6573998132392561 +253.92728408654438,322.1598732641919,-0.6789661435220066 +256.1117505650275,320.39711176267366,-0.7001076793498743 +258.2604510678927,318.58689074360655,-0.7208185369708761 +260.3736760675996,316.7303863699767,-0.7410945465255664 +262.4517160366082,314.8287748047704,-0.7609331405357204 +264.49486144737847,312.88323221097414,-0.7803332374459837 +266.5034027723701,310.89493475157417,-0.799295122536397 +268.477630484043,308.8650585895566,-0.8178203282515935 +270.41783505485705,306.794779887908,-0.8359115157075723 +272.32430695727214,304.6852748096148,-0.8535723588545023 +274.19733666374805,302.53771951766294,-0.8708074325031351 +276.03721464674476,300.35329017503886,-0.8876221051681148 +277.84423137872204,298.1331629447291,-0.9040224374458761 +279.6186773321398,295.8785139897198,-0.9200150864373708 +281.36084297945786,293.5905194729973,-0.935607216542308 +283.07101879313615,291.27035555754793,-0.9508064103197408 +284.749495278449,288.9191983925752,-0.9656184661284246 +286.39657355363755,286.5382196696349,-0.9800461661953658 +288.0125805286927,284.12858024725665,-0.9940967121148474 +289.59784695623614,281.69143936999245,-1.0077782581042973 +291.15270358888904,279.227956282394,-1.0210990977778704 +292.6774811792732,276.7392902290133,-1.034067614494482 +294.1725104800101,274.22660045440205,-1.0466922369748426 +295.63812224372117,271.6910462031122,-1.058981399855239 +297.0746472230279,269.1337867196955,-1.0709435088318602 +298.482416170552,266.55598124870386,-1.0825869100473453 +299.86175983891485,263.9587890346891,-1.0939198633721243 +301.21300898073804,261.3433693222031,-1.1049505192441997 
+302.53649434864303,258.71088135579765,-1.1156868987392934 +303.8325466952512,256.0624843800246,-1.1261368765602024 +305.1014967731843,253.39933763943577,-1.136308166649881 +306.3436753350637,250.72260037858314,-1.146208310149329 +307.5594131335111,248.03343184201842,-1.1558446654409986 +308.749040921148,245.33299127429348,-1.165224400035634 +309.9128894505957,242.6224379199602,-1.174354484078279 +311.0512894744759,239.90293102357043,-1.183241685267863 +312.16457174541,237.17562982967598,-1.191892564999114 +313.25306701601966,234.4416935828286,-1.200313475553733 +314.3171060389262,231.7022815275803,-1.2085105581801967 +315.35701956675143,228.9585529084829,-1.2164897419169813 +316.3731383521166,226.21166697008812,-1.2242567430241087 +317.36579314764333,223.46278295694788,-1.2318170649018885 +318.3353147059532,220.71306011361406,-1.2391759983829833 +319.2820337796676,217.96365768463846,-1.2463386222942683 +320.20628112140815,215.21573491457292,-1.2533098041926332 +321.1083874837964,212.47045104796933,-1.2600942011842169 +321.9886836194538,209.72896532937952,-1.2666962607431411 +322.84750028100177,206.99243700335526,-1.2731202214486563 +323.685168221062,204.26202531444846,-1.2793701135644695 +324.50201819225595,201.53888950721097,-1.285449759383144 +325.2983809472051,198.82418882619456,-1.2913627732614752 +326.0745872385311,196.11908251595122,-1.2971125612713275 +326.83096781885524,193.42472982103266,-1.3027023203868855 diff --git a/demo_images/satellite/sat_demo_5.png b/demo_images/satellite/sat_demo_5.png new file mode 100644 index 0000000000000000000000000000000000000000..7d08a4b08f874239bc5cc00559daa5125d88ce2c --- /dev/null +++ b/demo_images/satellite/sat_demo_5.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:8111683ae0a0dd9d1ff6787b82f3da3055ba272c7619d545c959eab8d5236ae7 +size 570369 diff --git a/demo_images/satellite/sat_demo_6.csv b/demo_images/satellite/sat_demo_6.csv new file mode 100644 index 0000000000000000000000000000000000000000..9b5041e2dcf62e44882cb2c9c13b5aac1b194ee0 --- /dev/null +++ b/demo_images/satellite/sat_demo_6.csv @@ -0,0 +1,80 @@ +w,h,angle +328.8626734632663,429.5073253394794,-2.1914327098340127 +327.1092330063393,427.05451567960483,-2.191182775930267 +325.35549460014425,424.5999927593531,-2.19087393885907 +323.60163501055655,422.14369684871616,-2.1905062706242027 +321.84783100345135,419.68556821768476,-2.1900798368477585 +320.09425934470426,417.2255471362508,-2.189594696835013 +318.3410968001905,414.7635738744055,-2.18905090364399 +316.58852013578513,412.2995887021401,-2.1884485041588717 +314.8367061173639,409.8335318894462,-2.1877875391681156 +313.0858315108019,407.3653437063152,-2.187068043445804 +311.33607308197423,404.89496442273804,-2.1862900458382737 +309.58760759675675,402.422334308707,-2.1854535693541206 +307.84061182102437,399.9473936342125,-2.1845586312592213 +306.09526252065257,397.4700826692466,-2.1836052431753137 +304.3517364615166,394.99034168380035,-2.182593411183798 +302.6102104094919,392.5081109478655,-2.1815231359332126 +300.8708611304538,390.0233307314332,-2.1803944127519634 +299.1338653902775,387.53594130449477,-2.1792072317650226 +297.3993999548383,385.04588293704165,-2.1779615780157138 +295.6676415900117,382.5530958990654,-2.176657431592084 +293.9387670616729,380.05752046055716,-2.17529476775814 +292.21295313569726,377.55909689150843,-2.1738735570902135 +290.4903765779601,375.05776546191083,-2.1723937656178425 +288.77121415433663,372.55346644175523,-2.1708553549707044 
+287.0556426307024,370.0461401010335,-2.1692582825298676 +285.3438387729326,367.53572670973693,-2.1676025015849554 +283.6359793469026,365.02216653785683,-2.1658879614967965 +281.93224111848764,362.5053998553845,-2.1641146078653923 +280.2328008535632,359.98536693231165,-2.1622823827038684 +278.5378353180045,357.4620080386294,-2.1603912246182007 +276.8475212776868,354.93526344432905,-2.158441068992894 +275.1620354984856,352.4050734194024,-2.1564318481827245 +273.4815547462762,349.87137823384046,-2.154363491711118 +271.8062557869336,347.33411815763486,-2.1522359264738604 +270.1363153863335,344.79323346077666,-2.1500490769509604 +268.47191031035123,342.24866441325787,-2.1478028654233823 +266.81321732486185,339.7003512850692,-2.1454972121978706 +265.1604131957408,337.1482343462024,-2.143132035837816 +263.5136746888635,334.5922538666489,-2.1407072534018092 +261.87317857010515,332.0323501164,-2.1382227806886784 +260.23910160534115,329.46846336544695,-2.135678532490438 +258.6116205604468,326.90053388378146,-2.133074422851944 +256.9909122012975,324.3285019413946,-2.130410365338866 +255.37715329376837,321.752307808278,-2.1276862733123156 +253.77052060373498,319.171891754423,-2.124902060212089 +252.17119089707248,316.5871940498209,-2.1220576398470454 +250.57934093965628,313.9981549644631,-2.1191529266938347 +248.9951474973617,311.4047147683411,-2.1161878362032907 +247.41878733606399,308.80681373144625,-2.1131622851149703 +245.85043722163857,306.20439212376994,-2.1100761917800255 +244.29027391996073,303.5973902153035,-2.1069294764918682 +242.73847419690574,300.9857482760384,-2.103722061825227 +241.19521481834903,298.369406575966,-2.1004538729835467 +239.66067255016583,295.74830538507774,-2.0971248381543104 +238.13502415823154,293.12238497336494,-2.093734888872963 +236.61844640842145,290.491585610819,-2.0902839603947236 +235.11111606661086,287.8558475674314,-2.0867719920747274 +233.61320989867517,285.2151111131934,-2.083198927756399 +232.1249046704896,282.56931651809657,-2.0795647161674706 +230.64637714792957,279.91840405213213,-2.0758693113245084 +229.17780409687032,277.2623139852915,-2.0721126729447104 +227.71936228318725,274.6009865875662,-2.068294766865717 +226.27122847275567,271.9343621289475,-2.0644155654728453 +224.83357943145086,269.26238087942676,-2.060475048133598 +223.4065919251482,266.5849831089955,-2.0564732016392684 +221.99044271972298,263.9021090876451,-2.052410020653393 +220.5853085810505,261.21369908536684,-2.0482855081667912 +219.19136627500617,258.51969337215223,-2.044099675958819 +217.80879256746525,255.82003221799258,-2.03985254506447 +216.43776422430304,253.1146558928793,-2.035544146247035 +215.07845801139499,250.40350466680385,-2.0311745204760094 +213.73105069461627,247.68651880975756,-2.0267437194092874 +212.3957190398423,244.9636385917318,-2.022251805879992 +211.07263981294838,242.23480428271802,-2.017698854386585 +209.7619897798098,239.49995615270757,-2.013084951586223 +208.46394570630196,236.7590344716919,-2.0084101967906607 +207.17868435830013,234.01197950966235,-2.003674702463851 +205.90638250167962,231.2587315366103,-1.9988785947207854 +204.64721690231585,228.4992308225272,-1.9940220138269091 diff --git a/demo_images/satellite/sat_demo_6.png b/demo_images/satellite/sat_demo_6.png new file mode 100644 index 0000000000000000000000000000000000000000..a1b46f98ddf8f1bf4ed5dc628f5eaedb2bdc0038 --- /dev/null +++ b/demo_images/satellite/sat_demo_6.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid 
sha256:9dacc77546a0392ae12514d6fc434f2f9731afd2e3b9c1a8b44766273cd75bd8 +size 611951 diff --git a/demo_images/satellite/sat_demo_7.csv b/demo_images/satellite/sat_demo_7.csv new file mode 100644 index 0000000000000000000000000000000000000000..05f7ba3ac94a215d37685d1163b5b2118dfc0d3b --- /dev/null +++ b/demo_images/satellite/sat_demo_7.csv @@ -0,0 +1,80 @@ +w,h,angle +184.49767298292676,408.9779150298452,0.020775574359455106 +188.2253441968674,409.05537068457784,0.013781120921666676 +191.953382291264,409.10675048107004,0.0070148796926462875 +195.68177413092286,409.13290513028437,0.00047743108067126617 +199.41050658064987,409.13468534318264,-0.005830763378367393 +203.13956650525128,409.11294183072744,-0.011909350946091957 +206.868940769533,409.0685253038805,-0.017758077852758118 +210.59861623830105,409.002286473604,-0.023376779746977244 +214.32857977636166,408.9150760508604,-0.028765372448362705 +218.05881824852085,408.8077447466116,-0.033923843023596516 +221.7893185195845,408.6811432718197,-0.03885224120533834 +225.52006745435892,408.536122337447,-0.04355067116709803 +229.25105191765,408.37353265445546,-0.04801928366457387 +232.9822587742639,408.19422493380745,-0.05225826855344862 +236.71367488900663,407.99904988646506,-0.05626784768749898 +240.4452871266843,407.7888582233903,-0.060048268201829895 +244.17708235210296,407.56450065554554,-0.06359979618290851 +247.9090474300686,407.3268278938927,-0.06692271072464116 +251.64116922538742,407.07669064939415,-0.07001729837072646 +255.37343460286533,406.8149396330118,-0.07288384793928385 +259.10583042730855,406.54242555570795,-0.07552264572828384 +262.838343563523,406.2599991284448,-0.07793397109684605 +266.5709608763148,405.9685110621843,-0.0801180924182596 +270.30366923048996,405.66881206788867,-0.08207526340094821 +274.0364554908547,405.36175285652007,-0.08380571977172208 +277.769306522215,405.04818413904087,-0.08530967631761177 +281.5022091893768,404.7289566264129,-0.08658732428063265 +285.23515035714627,404.4049210295984,-0.0876388291027453 +288.9681168903295,404.0769280595596,-0.08846432851592767 +292.7010956537326,403.74582842725863,-0.08906393097467219 +296.4340735121614,403.41247284365744,-0.08943771442718136 +300.16703733042226,403.07771201971855,-0.08958572542465881 +303.89997397332104,402.74239666640375,-0.08950797856441242 +307.63287030566374,402.4073774946753,-0.08920445626837932 +311.3657131922567,402.0735052154954,-0.08867510889485102 +315.0984894979058,401.74163053982636,-0.08791985518471102 +318.8311860874171,401.41260417862986,-0.08693858304199563 +322.56378982559664,401.08727684286845,-0.08573115065250769 +326.2962875772506,400.7664992435042,-0.08429738794026588 +330.028666207185,400.45112209149914,-0.0826370983662422 +333.76091258020585,400.14199609781565,-0.08075006107275078 +337.4930135611192,399.83997197341546,-0.0786360333759069 +341.2249560147312,399.5459004292612,-0.07629475361304142 +344.9567268058479,399.26063217631474,-0.07372594434632733 +348.6883127992753,398.9850179255383,-0.07092931593045933 +352.4197008598194,398.7199083878939,-0.06790457044702582 +356.1508778522865,398.4661542743439,-0.06465140601173379 +359.88183064148245,398.22460629585026,-0.06116952145769757 +363.61254609221334,397.99611516337524,-0.0574586214004325 +367.34301106928535,397.781531587881,-0.05351842168671587 +371.0732124375044,397.5817062803296,-0.04934865523087928 +374.8031370616766,397.39748995168327,-0.04494911444206356 +378.53277178890556,397.2297331784064,-0.04032048071587482 +382.2621030284463,397.0792832103393,-0.035463507998669305 
+385.99111673394367,396.9469838357551,-0.030378105522839205 +389.7197988377868,396.83367868142705,-0.02506419032707613 +393.4481352723643,396.74021137412865,-0.019521738727584188 +397.17611197006494,396.6674255406333,-0.013750794674067956 +400.903714863277,396.61616480771403,-0.007751478447163058 +404.6309298843899,396.58727280214487,-0.0015239956834812384 +408.35774296579183,396.58159315069855,0.004931353295074456 +412.08414003987184,396.59996948014884,0.011614163866589478 +415.81010703901836,396.64324541726904,0.01852391724046639 +419.53562989562033,396.71226458883234,0.02565997036501671 +423.2606945420663,396.80787062161244,0.033021545648096305 +426.9852869107451,396.9309071423825,0.0406077205341363 +430.7093929340456,397.082217777916,0.048417416982965006 +434.43299854435617,397.2626461549862,0.05644939090407887 +438.1560896740658,397.47303590036665,0.06470222160335283 +441.878652255563,397.71423064083046,0.07317430130654617 +445.6006722212367,397.98707400315135,0.08186382482676367 +449.3221355034756,398.29240961410255,0.09076877945160382 +453.0430280346683,398.6310811004574,0.09988693512752701 +456.76333574720354,399.0039320889892,0.10921583502568676 +460.48304457347,399.4118062064715,0.11875278657660139 +464.20214044585657,399.85554707967754,0.12849485306570138 +467.9206092967518,400.3359983353808,0.13843884588272523 +471.6384370585445,400.8540036003547,0.1485813175222422 +475.35560966362334,401.41040650137245,0.15891855543207917 diff --git a/demo_images/satellite/sat_demo_7.png b/demo_images/satellite/sat_demo_7.png new file mode 100644 index 0000000000000000000000000000000000000000..1a991b102f125e474931fb152cbb268cd0a37667 --- /dev/null +++ b/demo_images/satellite/sat_demo_7.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:afbfcdaafa4ab7d531fe4b23759c69d91756fcbbc8c41e1688eaf78f31f24b6a +size 591762 diff --git a/demo_images/satellite/sat_demo_8.csv b/demo_images/satellite/sat_demo_8.csv new file mode 100644 index 0000000000000000000000000000000000000000..d9d7bbedfa8e684ac4982e53e7e430cac7f66dcb --- /dev/null +++ b/demo_images/satellite/sat_demo_8.csv @@ -0,0 +1,80 @@ +w,h,angle +241.19971665249935,468.0632949352942,-1.474988639273778 +241.50705663965928,464.8652317146466,-1.487691839382883 +241.77327873644782,461.6691475097032,-1.500075543548405 +241.99954928606053,458.47499173217165,-1.5121328179333111 +242.18703463169254,455.2827137937588,-1.523857377364663 +242.33690111653934,452.09226310617197,-1.5352435618247704 +242.4503150837963,448.9035890811185,-1.5462863101308753 +242.5284428766587,445.71664113030533,-1.5569811312874826 +242.57245083832186,442.5313686654401,-1.567324073978034 +242.58350531198127,439.34772109822956,-1.5773116946387438 +242.56277264083207,436.16564784038104,-1.5869410245297808 +242.51141916806986,432.985098303602,-1.596209536189494 +242.43061123688977,429.80602189959933,-1.6051151096206442 +242.32151519048733,426.62836804008043,-1.6136559985273515 +242.18529737205782,423.4520861367524,-1.6218307968824481 +242.02312412479657,420.2771256013227,-1.6296384060735742 +241.83616179189903,417.1034358454982,-1.637078002841616 +241.6255767165604,413.9309662809863,-1.644149008193074 +241.3925352419761,410.75966631949416,-1.6508510574398088 +241.13820371134156,407.5894853727291,-1.6571839714903742 +240.86374846785196,404.42037285239803,-1.6631477294930437 +240.57033585470288,401.2522781702084,-1.6687424429089956 +240.25913221508952,398.0851507378675,-1.6739683310722122 +239.93130389220713,394.9189399670822,-1.6788256982777832 
+239.58801722925125,391.75359526956004,-1.6833149124243454 +239.2304385694173,388.58906605700827,-1.687436385223653 +238.8597342559004,385.4253017411338,-1.6911905539811118 +238.47707063189603,382.2622517336439,-1.694577864943033 +238.0836140405996,379.0998654462461,-1.6975987581994 +237.68053082520635,375.93809229064726,-1.7002536541273388 +237.2689873289116,372.7768816785546,-1.7025429413577484 +236.85014989491083,369.61618302167557,-1.7044669662458152 +236.42518486639935,366.45594573171724,-1.7060260238259988 +235.99525858657245,363.29611922038674,-1.7072203502332586 +235.5615373986255,360.1366528993914,-1.7080501165739366 +235.12518764575404,356.9774961804385,-1.7085154242324854 +234.68737567115312,353.818598475235,-1.7086163016021592 +234.24926781801824,350.65990919548824,-1.708352702233986 +233.81203042954488,347.50137775290557,-1.7077245043992952 +233.3768298489281,344.34295355919403,-1.7067315120665723 +232.9448324193635,341.1845860260608,-1.7053734572986576 +232.51720448404637,338.0262245652132,-1.7036500040774858 +232.095112386172,334.86781858835843,-1.7015607535702126 +231.6797224689358,331.70931750720365,-1.6991052508526312 +231.27220107553308,328.5506707334562,-1.6962829931084582 +230.87371454915922,325.391827678823,-1.693093439326745 +230.48542923300963,322.23273775501156,-1.6895360215195554 +230.10851147027955,319.0733503737289,-1.6856101574836344 +229.74412760416442,315.91361494668234,-1.6813152651290817 +229.39344397785956,312.75348088557905,-1.676650778395486 +229.0576269345603,309.5928976021262,-1.6716161647731973 +228.737842817462,306.4318145080311,-1.6662109444420825 +228.43525796976004,303.2701810150009,-1.6604347110329216 +228.15103873464977,300.1079465347428,-1.6542871540077424 +227.88635145532652,296.945060478964,-1.6477680826440158 +227.6423624749857,293.78147225937175,-1.6408774515940892 +227.42023813682255,290.61713128767315,-1.6336153879742776 +227.22114478403253,287.4519869755756,-1.625982219920496 +227.04624875981096,284.2859887347862,-1.6179785065238694 +226.89671640735315,281.11908597701216,-1.6096050690368824 +226.7737140698545,277.95122811396067,-1.6008630232127938 +226.67840809051037,274.78236455733895,-1.5917538126114001 +226.61196481251608,271.6124447188543,-1.5822792426722647 +226.57555057906694,268.4414180102138,-1.5724415153227635 +226.57033173335844,265.26923384312477,-1.5622432638524655 +226.5974746185858,262.0958416292943,-1.5516875877482663 +226.65814557794442,258.9211907804297,-1.5407780871491954 +226.75351095462975,255.74523070823815,-1.5295188965421995 +226.88473709183694,252.56791082442686,-1.5179147172858314 +227.05299033276145,249.38918054070297,-1.5059708485180126 +227.2594370205987,246.20898926877385,-1.4936932159743397 +227.505243498544,243.02728642034657,-1.481088398221927 +227.7915761097926,239.84402140712842,-1.4681636497978898 +228.11960119753994,236.65914364082653,-1.454926920733424 +228.49048510498136,233.47260253314812,-1.4413868719463354 +228.90539417531224,230.28434749580055,-1.4275528859976083 +229.36549475172794,227.0943279404908,-1.4134350727317055 +229.8719531774237,223.90249327892627,-1.3990442693574123 +230.42593579559505,220.7087929228141,-1.3843920345767522 diff --git a/demo_images/satellite/sat_demo_8.png b/demo_images/satellite/sat_demo_8.png new file mode 100644 index 0000000000000000000000000000000000000000..7978ad47e4561a8b7d44d373e759d452fb67e7f1 --- /dev/null +++ b/demo_images/satellite/sat_demo_8.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid 
sha256:7b0c766b9378e36a73b26c3aeebc1d95a0fe8456422c7bea3a94fd3d31475964 +size 506297 diff --git a/demo_images/satellite/sat_demo_9.csv b/demo_images/satellite/sat_demo_9.csv new file mode 100644 index 0000000000000000000000000000000000000000..03310db379df2acef7b024b81ae1520a29804154 --- /dev/null +++ b/demo_images/satellite/sat_demo_9.csv @@ -0,0 +1,80 @@ +w,h,angle +360.53626787533517,450.13784971908416,-1.9618677495104848 +359.4413117302415,447.48217399611497,-1.9625105066862283 +358.3433916612801,444.8241589988592,-1.962995351015549 +357.2429835055221,442.16377670276506,-1.963322876763445 +356.1405631000379,439.5009990832798,-1.9634936249661643 +355.0366062818989,436.8357981158516,-1.9635080837568193 +353.93158888817567,434.16814577592834,-1.9633666887617027 +352.8259867559391,431.49801403895736,-1.9630698235676676 +351.72027572226017,428.82537488038696,-1.9626178202571811 +350.6149316242096,426.15020027566453,-1.9620109600113431 +349.5104302988581,423.472462200238,-1.9612494737796227 +348.40724758327707,420.79213262955534,-1.9603335430173694 +347.30585931453686,418.1091835390641,-1.9592633004894173 +346.20674132970856,415.4235869042122,-1.9580388311431178 +345.11036946586296,412.7353147004474,-1.956660173049801 +344.01721956007117,410.04433890321764,-1.9551273184179183 +342.92776744940375,407.35063148797053,-1.953440214678094 +341.8424889709317,404.6541644301539,-1.9515987656443952 +340.7618599617258,401.95490970521564,-1.9496028327525423 +339.68635625885713,399.25283928860347,-1.9474522363798275 +338.6164536993962,396.5479251557652,-1.9451467572479904 +337.5526281204141,393.8401392821485,-1.942686137915392 +336.4953553589818,391.1294536432015,-1.940070084359534 +335.44511125216985,388.4158402143716,-1.9372982676560844 +334.40237163704944,385.6992709711068,-1.934370325758447 +333.3676123506914,382.9797178888551,-1.9312858653817213 +332.34130923016636,380.25715294306383,-1.9280444639970746 +331.3239381125452,377.53154810918096,-1.9246456719407272 +330.31597483489924,374.8028753626547,-1.9210890146438309 +329.3178952342989,372.0711066789324,-1.917373994986734 +328.330175147815,369.3362140334618,-1.913500095785429 +327.3532904125187,366.59816940169094,-1.9094667824139668 +326.3877168654808,363.8569447590677,-1.9052735055690015 +325.433930343772,361.1125120810394,-1.9009197041817378 +324.49240668446316,358.3648433430542,-1.8964048084829768 +323.5636217246255,355.6139105205601,-1.8917282432263571 +322.6480513013295,352.85968558900447,-1.8868894310743742 +321.7461712516461,350.10214052383526,-1.8818877961527112 +320.8584574126464,347.34124730050036,-1.8767227677768794 +319.98538562140095,344.5769778944474,-1.8713937843547293 +319.12743171498084,341.8093042811243,-1.8659002974692989 +318.28507153045683,339.03819843597887,-1.8602417761436725 +317.45878090489975,336.2636323344588,-1.8544177112908145 +316.6490356753807,333.48557795201197,-1.8484276203495262 +315.85631167897026,330.70400726408616,-1.8422710521061294 +315.0810847527394,327.9188922461291,-1.8359475917025043 +314.32383073375905,325.1302048735887,-1.8294568658275472 +313.5850254591,322.3379171219127,-1.822798548089236 +312.8651447658332,319.54200096654887,-1.8159723645630368 +312.16466449102944,316.74242838294504,-1.8089780995099263 +311.4840604717595,313.93917134654896,-1.8018156012564848 +310.82380854509444,311.1322018328085,-1.7944847882278616 +310.18438454810496,308.3214918171715,-1.7869856551207357 +309.5662643178621,305.50701327508557,-1.7793182792041813 +308.9699236914366,302.68873818199864,-1.7714828267309497 
+308.39583850589935,299.8666385133585,-1.7634795594422066 +307.8444845983212,297.04068624461286,-1.7553088411441844 +307.31633780577306,294.21085335120966,-1.746971144334373 +306.81187396532573,291.3771118085966,-1.7384670568506582 +306.33156891405014,288.5394335922215,-1.729797288516192 +305.8758984890171,285.6977906775321,-1.7209626777480547 +305.4453385272976,282.8521550399763,-1.7119641980970948 +305.0403648659624,280.00249865500183,-1.7028029646820249 +304.66145334208227,277.1487934980565,-1.6934802404800888 +304.3090797927283,274.29101154458806,-1.6839974424330215 +303.9837200549712,271.4291247700444,-1.6743561473247126 +303.68584996588186,268.56310514987325,-1.6645580973868788 +303.4159453625312,265.6929246595225,-1.6546052055845435 +303.17448208199005,262.81855527443975,-1.6444995605341413 +302.96193596132923,259.939968970073,-1.6342434310046112 +302.7787828376197,257.05713772186994,-1.6238392699519926 +302.6254985479323,254.1700335052784,-1.613289718037062 +302.50255892933785,251.2786282957462,-1.6025976065767873 +302.4104398189072,248.38289406872101,-1.5917659598805842 +302.34961705371126,245.48280279965076,-1.5807979969241417 +302.320566470821,242.57832646398325,-1.569697132315761 +302.3237639073071,239.66943703716618,-1.558466976512026 +302.35968520024045,236.7561064946474,-1.5471113352446915 +302.4288061866921,233.83830681187476,-1.5356342081234624 diff --git a/demo_images/satellite/sat_demo_9.png b/demo_images/satellite/sat_demo_9.png new file mode 100644 index 0000000000000000000000000000000000000000..c238d5ca657ec16bee5cdbcbebfe82dcb7df5d98 --- /dev/null +++ b/demo_images/satellite/sat_demo_9.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:8b2b87d31e86001a8dbec90a72c902eb2aaf168b703603a6b14129d9d59f6cfa +size 554771 diff --git a/demo_images/satellite/satellite_47.564866521505394_-122.33116389229833.png b/demo_images/satellite/satellite_47.564866521505394_-122.33116389229833.png new file mode 100644 index 0000000000000000000000000000000000000000..c0de901984783c837dc61a88463251e53e538526 --- /dev/null +++ b/demo_images/satellite/satellite_47.564866521505394_-122.33116389229833.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:6d9acb5afc3bea342cb441dba57fa1aa888b8dea403b212cdf32a049ceb12975 +size 301151 diff --git a/demo_images/satellite/satellite_47.57109989810551_-122.30149079418568.png b/demo_images/satellite/satellite_47.57109989810551_-122.30149079418568.png new file mode 100644 index 0000000000000000000000000000000000000000..151b3f0f1e8563f0a59faef56acea06f140ae1d8 --- /dev/null +++ b/demo_images/satellite/satellite_47.57109989810551_-122.30149079418568.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:51baa9800c3fd8481a66f6ea17552784ab119fdb3e7d7de4c3d83cc041e5cff1 +size 313643 diff --git a/demo_images/satellite/satellite_47.57372447772661_-122.30489590380517.png b/demo_images/satellite/satellite_47.57372447772661_-122.30489590380517.png new file mode 100644 index 0000000000000000000000000000000000000000..80f4597bc8e3073263e9c9256d7cb6c4afeed73d --- /dev/null +++ b/demo_images/satellite/satellite_47.57372447772661_-122.30489590380517.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:66942d1c25661007d34f336cc736776a8a0b752986cf9a4fae433ebf1040177c +size 292268 diff --git a/demo_images/satellite/satellite_47.574380622631885_-122.32192145190258.png b/demo_images/satellite/satellite_47.574380622631885_-122.32192145190258.png new file mode 100644 index 
0000000000000000000000000000000000000000..09b9711ff5cc093e3d8205811570ed938e5850dc --- /dev/null +++ b/demo_images/satellite/satellite_47.574380622631885_-122.32192145190258.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:70bb7c18780ca25aa38f903bbd72f87de91b2b51e3cb76e601ab79e77b8bcae4 +size 322604 diff --git a/demo_images/satellite/satellite_47.57503676753716_-122.33456900191781.png b/demo_images/satellite/satellite_47.57503676753716_-122.33456900191781.png new file mode 100644 index 0000000000000000000000000000000000000000..160342f08c4700bd1b7be03c143de2cccf724b02 --- /dev/null +++ b/demo_images/satellite/satellite_47.57503676753716_-122.33456900191781.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:d58ac4a4eec272869f4859988cfd60addb2719f6f868c5200ab41ac04ca289e7 +size 289156 diff --git a/demo_images/satellite/satellite_47.576020984895074_-122.31657056535768.png b/demo_images/satellite/satellite_47.576020984895074_-122.31657056535768.png new file mode 100644 index 0000000000000000000000000000000000000000..77da1b95be809d6d14843c2875a4203795656892 --- /dev/null +++ b/demo_images/satellite/satellite_47.576020984895074_-122.31657056535768.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:557b7be4186ab36eff4b7f1d650c7027c2763a1b2620e1d4c21ce5bf66124139 +size 263395 diff --git a/demo_images/satellite/satellite_47.5779894196109_-122.31608412112634.png b/demo_images/satellite/satellite_47.5779894196109_-122.31608412112634.png new file mode 100644 index 0000000000000000000000000000000000000000..59ca2c664c61ac82077c35f6c840cabb22ff23d5 --- /dev/null +++ b/demo_images/satellite/satellite_47.5779894196109_-122.31608412112634.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:fbf1f99a1d2ef7c54dbf975d798a2d6998ed3697e0a5996dc33bc2b358aa5edb +size 291093 diff --git a/demo_images/satellite/satellite_47.58356665130574_-122.2873839114764.png b/demo_images/satellite/satellite_47.58356665130574_-122.2873839114764.png new file mode 100644 index 0000000000000000000000000000000000000000..7dc1c2ef7c5300a3b4993aa70766ba88fe23320b --- /dev/null +++ b/demo_images/satellite/satellite_47.58356665130574_-122.2873839114764.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:dafd791532f112df08e5ff8dcebed1af164746c293f2f00e97ffb3e3c717c6ac +size 173458 diff --git a/demo_images/satellite/satellite_47.59012810035849_-122.32240789613394.png b/demo_images/satellite/satellite_47.59012810035849_-122.32240789613394.png new file mode 100644 index 0000000000000000000000000000000000000000..e94ccbe4895cb9b1734313585174e073298e1163 --- /dev/null +++ b/demo_images/satellite/satellite_47.59012810035849_-122.32240789613394.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:a815fefff411ccb948828c3c34eac35420879bf83fc0821c0fd820dca4f1d107 +size 278646 diff --git a/demo_images/satellite/satellite_47.59111231771641_-122.28981613263318.png b/demo_images/satellite/satellite_47.59111231771641_-122.28981613263318.png new file mode 100644 index 0000000000000000000000000000000000000000..8fded7b6b3e03348562b0ada0305582cc13a76d3 --- /dev/null +++ b/demo_images/satellite/satellite_47.59111231771641_-122.28981613263318.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:043dbcda120490e09464576d434466d315d6d31db91b893c53d3a71ed1a5e5be +size 278128 diff --git a/demo_images/satellite/satellite_47.59176846262168_-122.29419413071537.png 
b/demo_images/satellite/satellite_47.59176846262168_-122.29419413071537.png new file mode 100644 index 0000000000000000000000000000000000000000..b83f7f8c15cc70168ad7922ebfbe67525dc438a4 --- /dev/null +++ b/demo_images/satellite/satellite_47.59176846262168_-122.29419413071537.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:9d7f71536a248d1b0907e82bae82b5c1ace2a42d31e58f4a5e73119fbd35aa49 +size 319679 diff --git a/demo_images/satellite/satellite_47.595377259600696_-122.32435367305936.png b/demo_images/satellite/satellite_47.595377259600696_-122.32435367305936.png new file mode 100644 index 0000000000000000000000000000000000000000..54598ff2c334b8f929ac614539cf9a21f8e91b48 --- /dev/null +++ b/demo_images/satellite/satellite_47.595377259600696_-122.32435367305936.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:57ac87f3910d22c4be9b9970f4dc4edbc73d2081eb1ac18d223dcadf4c0764fc +size 286491 diff --git a/demo_images/satellite/satellite_47.59767376676916_-122.33651477884322.png b/demo_images/satellite/satellite_47.59767376676916_-122.33651477884322.png new file mode 100644 index 0000000000000000000000000000000000000000..58f05f7fa07a045abe0da6cb999e91596a5b604d --- /dev/null +++ b/demo_images/satellite/satellite_47.59767376676916_-122.33651477884322.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:4b05c79fd28421ce448411b539473313c8e12787c7bbfda6c795103ef4e6a68c +size 286933 diff --git a/demo_images/satellite/satellite_47.6006264188429_-122.29419413071537.png b/demo_images/satellite/satellite_47.6006264188429_-122.29419413071537.png new file mode 100644 index 0000000000000000000000000000000000000000..7387394f56f2a9445f1e5930add8d236cb1546ca --- /dev/null +++ b/demo_images/satellite/satellite_47.6006264188429_-122.29419413071537.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:09e7a126c03941cfcce1fd34ab2500f4ead8cee71e5379e42f5f52ec5dd7d0af +size 279589 diff --git a/demo_images/satellite/satellite_47.60390714336928_-122.31559767689498.png b/demo_images/satellite/satellite_47.60390714336928_-122.31559767689498.png new file mode 100644 index 0000000000000000000000000000000000000000..6d48724c80cf9c7b3db80b4d723f3cd7e9cb6b4b --- /dev/null +++ b/demo_images/satellite/satellite_47.60390714336928_-122.31559767689498.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:98cf1a2bd182face991732f8787fd69824663f98a5513875dbeabc6de29a0911 +size 237631 diff --git a/demo_images/satellite/satellite_47.616373896569506_-122.31024679035008.png b/demo_images/satellite/satellite_47.616373896569506_-122.31024679035008.png new file mode 100644 index 0000000000000000000000000000000000000000..814bdbb3ee9619aafa6cd8275a3463059e606235 --- /dev/null +++ b/demo_images/satellite/satellite_47.616373896569506_-122.31024679035008.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:772ebb7b67a2a5832680d0c4e7c8e7a1a54a018eae3f93d46f6b14796d35938e +size 290355 diff --git a/demo_images/satellite/satellite_47.616373896569506_-122.3131654557382.png b/demo_images/satellite/satellite_47.616373896569506_-122.3131654557382.png new file mode 100644 index 0000000000000000000000000000000000000000..84b4fe90cfe3800807c24f7958b0c0bcbe52643e --- /dev/null +++ b/demo_images/satellite/satellite_47.616373896569506_-122.3131654557382.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:74788285b6e531232c3476c854154e45985418323f63096046e4603b8012b1e3 +size 257286 diff --git 
a/demo_images/satellite/satellite_47.618014258832694_-122.29857212879756.png b/demo_images/satellite/satellite_47.618014258832694_-122.29857212879756.png new file mode 100644 index 0000000000000000000000000000000000000000..3a7e898f019e470c96c3eeb7bcd8c342b956ca34 --- /dev/null +++ b/demo_images/satellite/satellite_47.618014258832694_-122.29857212879756.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:f6404fd14cf1abc36a8bbacd2d998ad0d3678ef6cb739da2274904b2cdacffff +size 299132 diff --git a/demo_images/satellite/satellite_47.61867040373797_-122.3024636826484.png b/demo_images/satellite/satellite_47.61867040373797_-122.3024636826484.png new file mode 100644 index 0000000000000000000000000000000000000000..2d1711dc0c2aa2821dc9798e75811a96c5b687a0 --- /dev/null +++ b/demo_images/satellite/satellite_47.61867040373797_-122.3024636826484.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:8d00e324b01d01699c391345ebf6897898057b0e9dccc8d2cb692db1e1382287 +size 264463 diff --git a/demo_images/satellite/satellite_47.62227920071699_-122.29176190955859.png b/demo_images/satellite/satellite_47.62227920071699_-122.29176190955859.png new file mode 100644 index 0000000000000000000000000000000000000000..fca366876489ebc3aabd04bb201d5d3ed599efd6 --- /dev/null +++ b/demo_images/satellite/satellite_47.62227920071699_-122.29176190955859.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:5c5cb66d6eae2c47b2f5b593d33b143b6b8aa9d7653f565e07f13fc14623566e +size 282461 diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000000000000000000000000000000000000..142fc986e77fc8c7f0c800586ac6523335576244 --- /dev/null +++ b/requirements.txt @@ -0,0 +1,19 @@ +torch +torchvision +timm +einops +scikit-image +PyYAML +easydict +kornia +transformers==4.57.1 +diffusers +peft +PyMCubes +numpy>=1.24,<2.0 +opencv-python-headless>4.8.1 +open3d +trimesh +gradio>=4.0 +pandas>=2.1.0,<2.4.0 +xformers diff --git a/source/__init__.py b/source/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/source/__pycache__/__init__.cpython-310.pyc b/source/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0155cef3a72d3d6fb7ca31b421dd73e8d50a72e9 Binary files /dev/null and b/source/__pycache__/__init__.cpython-310.pyc differ diff --git a/source/__pycache__/__init__.cpython-38.pyc b/source/__pycache__/__init__.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..705afb6238c2709cb7c7acbc3e2c4c0f7bddff4b Binary files /dev/null and b/source/__pycache__/__init__.cpython-38.pyc differ diff --git a/source/__pycache__/__init__.cpython-39.pyc b/source/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3a8d7992dc467d23e6320505776d8fa4c1636cc4 Binary files /dev/null and b/source/__pycache__/__init__.cpython-39.pyc differ diff --git a/source/__pycache__/dino_3d_model.cpython-310.pyc b/source/__pycache__/dino_3d_model.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e11b7aa27fe01fcb81ba77cba7afd1bc54eb9d39 Binary files /dev/null and b/source/__pycache__/dino_3d_model.cpython-310.pyc differ diff --git a/source/__pycache__/dino_3d_model.cpython-38.pyc b/source/__pycache__/dino_3d_model.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..94a434c7bd688290688521784e52fc1b450f3080 
Binary files /dev/null and b/source/__pycache__/dino_3d_model.cpython-38.pyc differ diff --git a/source/__pycache__/dino_3d_model.cpython-39.pyc b/source/__pycache__/dino_3d_model.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ae495f5a3e7a9f443b3b1296acfdf896e87abe17 Binary files /dev/null and b/source/__pycache__/dino_3d_model.cpython-39.pyc differ diff --git a/source/__pycache__/discriminators.cpython-310.pyc b/source/__pycache__/discriminators.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7e8e4f48f5b8b1a1931081e6681b7b8a4164b017 Binary files /dev/null and b/source/__pycache__/discriminators.cpython-310.pyc differ diff --git a/source/__pycache__/generator.cpython-310.pyc b/source/__pycache__/generator.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..91b7d1d4eb8fe858a9eccaf0c063da2a8c6766be Binary files /dev/null and b/source/__pycache__/generator.cpython-310.pyc differ diff --git a/source/__pycache__/sky_histogram.cpython-310.pyc b/source/__pycache__/sky_histogram.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5ff96c88ae007d12f2bfa88609ac8832ec87c449 Binary files /dev/null and b/source/__pycache__/sky_histogram.cpython-310.pyc differ diff --git a/source/__pycache__/sr_module.cpython-310.pyc b/source/__pycache__/sr_module.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f7d11287037dfa95bd0b3048f53bd326994e1eae Binary files /dev/null and b/source/__pycache__/sr_module.cpython-310.pyc differ diff --git a/source/__pycache__/sr_module.cpython-38.pyc b/source/__pycache__/sr_module.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a6c565bfd91fac656596fa1ebc929f6909a7fffa Binary files /dev/null and b/source/__pycache__/sr_module.cpython-38.pyc differ diff --git a/source/__pycache__/training_utils.cpython-310.pyc b/source/__pycache__/training_utils.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c00916acdc15d1bceae1c3ae6f9e47d39c973a95 Binary files /dev/null and b/source/__pycache__/training_utils.cpython-310.pyc differ diff --git a/source/__pycache__/unet_hacked.cpython-310.pyc b/source/__pycache__/unet_hacked.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f20e60ba16eb2827423b6ff2df5f7bc03c77427a Binary files /dev/null and b/source/__pycache__/unet_hacked.cpython-310.pyc differ diff --git a/source/__pycache__/unet_hacked.cpython-38.pyc b/source/__pycache__/unet_hacked.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..51dd676203b0ad55dbcf866d5027def9d5d81679 Binary files /dev/null and b/source/__pycache__/unet_hacked.cpython-38.pyc differ diff --git a/source/__pycache__/vae_hacked.cpython-310.pyc b/source/__pycache__/vae_hacked.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..18a9ea82e048e1823a03608f81002fe2c4a66404 Binary files /dev/null and b/source/__pycache__/vae_hacked.cpython-310.pyc differ diff --git a/source/__pycache__/vae_hacked.cpython-38.pyc b/source/__pycache__/vae_hacked.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8159649ee81ed9c768ca341c69d2160d5b95c44d Binary files /dev/null and b/source/__pycache__/vae_hacked.cpython-38.pyc differ diff --git a/source/__pycache__/xyz2thetaphi.cpython-310.pyc b/source/__pycache__/xyz2thetaphi.cpython-310.pyc new file mode 100644 index 
0000000000000000000000000000000000000000..5d4d73cb1bbc9bd89ea1c07d74da8ee5107effbb Binary files /dev/null and b/source/__pycache__/xyz2thetaphi.cpython-310.pyc differ diff --git a/source/__pycache__/xyz2thetaphi.cpython-38.pyc b/source/__pycache__/xyz2thetaphi.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9860caabf7f12cd31a4e905539f52850ddb9160c Binary files /dev/null and b/source/__pycache__/xyz2thetaphi.cpython-38.pyc differ diff --git a/source/discriminators.py b/source/discriminators.py new file mode 100644 index 0000000000000000000000000000000000000000..e4f5aec81b68322a7bf290876f869b1bcf172e2a --- /dev/null +++ b/source/discriminators.py @@ -0,0 +1,53 @@ +from math import log2 +import torch +from torch import nn + +from diffusers.configuration_utils import ConfigMixin, register_to_config +from diffusers.models.modeling_utils import ModelMixin +import torch.nn.functional as F + + +# Discriminator model ported from Paella https://github.com/dome272/Paella/blob/main/src_distributed/vqgan.py +class Discriminator(ModelMixin, ConfigMixin): + @register_to_config + def __init__(self, in_channels=3, cond_channels=0, hidden_channels=512, img_resolution=256): + super().__init__() + depth = int(log2(img_resolution) -2) + d = max(depth - 3, 3) + layers = [ + nn.utils.spectral_norm( + nn.Conv2d(in_channels, hidden_channels // (2**d), kernel_size=3, stride=2, padding=1) + ), + nn.LeakyReLU(0.2), + ] + for i in range(depth - 1): + c_in = hidden_channels // (2 ** max((d - i), 0)) + c_out = hidden_channels // (2 ** max((d - 1 - i), 0)) + layers.append(nn.utils.spectral_norm(nn.Conv2d(c_in, c_out, kernel_size=3, stride=2, padding=1))) + layers.append(nn.InstanceNorm2d(c_out)) + layers.append(nn.LeakyReLU(0.2)) + self.encoder = nn.Sequential(*layers) + self.shuffle = nn.Conv2d( + (hidden_channels + cond_channels) if cond_channels > 0 else hidden_channels, 1, kernel_size=1 + ) + # self.fc = nn.Linear(hidden_channels * int((img_resolution/(2**depth))**2), + # hidden_channels,) + # self.out = nn.Linear(hidden_channels,1) + # self.logits = nn.Sigmoid() + + def forward(self, x, cond=None): + x = self.encoder(x) + if cond is not None: + cond = cond.view( + cond.size(0), + cond.size(1), + 1, + 1, + ).expand(-1, -1, x.size(-2), x.size(-1)) + x = torch.cat([x, cond], dim=1) + x = self.shuffle(x) + x = x.flatten(1).mean(-1) + # x = self.fc(x.flatten(1)) + # x = F.leaky_relu(x, 0.2) + # x = self.out(x) + return x diff --git a/source/generator.py b/source/generator.py new file mode 100644 index 0000000000000000000000000000000000000000..a6f4e5471b1b09d43590e0a87c4111e6713164c9 --- /dev/null +++ b/source/generator.py @@ -0,0 +1,996 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +# from diffusers import DDIMScheduler +import numpy as np +import math + +# from diffusers import StableDiffusionPipeline, DDIMScheduler +# from pytorch_lightning import LightningModule, Trainer + + +from transformers import AutoConfig, AutoModel +from transformers.modeling_utils import no_init_weights + +import warnings + +warnings.filterwarnings("ignore") + +import os +from pathlib import Path +from einops import rearrange, repeat + +from easydict import EasyDict as edict + +from source.vae_hacked import Decoder +from source.rendering.utils import sample_importance,unify_attributes, create_voxel +from source.rendering.point_representer import PointRepresenter +from source.rendering.point_integrator import PointIntegrator +from source.rendering.sat2density_transform_eg3d import 
get_original_coord,Point_sampler_pano,Point_sampler_ortho +from source.rendering.transform_perspective import PointSamplerPerspective +from source.rendering.mlp_model import MLPNetwork2 +from source.sr_module import SuperresolutionHybrid2X +from source.xyz2thetaphi import xyz2thetaphi + +from diffusers.configuration_utils import ConfigMixin, register_to_config +from diffusers.models.modeling_utils import ModelMixin +import tqdm + +def normalize_2nd_moment(x, dim=1, eps=1e-8): + return x * (x.square().mean(dim=dim, keepdim=True) + eps).rsqrt() + + +def resolve_backbone_candidates(backbone): + env_override_map = { + "dinov2-base": "SAT3DGEN_DINOV2_BASE_PATH", + "dinov2-large": "SAT3DGEN_DINOV2_LARGE_PATH", + "dinov3-large-sat": "SAT3DGEN_DINOV3_SAT_PATH", + "dinov3-large-lvd": "SAT3DGEN_DINOV3_LVD_PATH", + } + default_candidate_map = { + "dinov2-base": [ + "facebook/dinov2-base", + ], + "dinov2-large": [ + "facebook/dinov2-large", + ], + "dinov3-large-sat": [ + "facebook/dinov3-vitl16-pretrain-sat493m", + ], + "dinov3-large-lvd": [ + "facebook/dinov3-vitl16-pretrain-lvd1689m", + ], + } + if backbone not in default_candidate_map: + raise NotImplementedError(f"Unsupported backbone: {backbone}") + + candidates = [] + env_override = os.environ.get(env_override_map[backbone]) + if env_override: + candidates.append(env_override) + candidates.extend(default_candidate_map[backbone]) + return candidates + + +# Built-in backbone configs so we can create the model structure without any +# network access (the gated DINOv3 repos require authentication even to fetch +# the config file). These are architecture-only settings and never change. +_BACKBONE_CONFIGS = { + "dinov2-base": { + "model_type": "dinov2", + "hidden_size": 768, + "num_hidden_layers": 12, + "num_attention_heads": 12, + "intermediate_size": 3072, + "patch_size": 14, + "image_size": 518, + "num_channels": 3, + "num_register_tokens": 0, + }, + "dinov2-large": { + "model_type": "dinov2", + "hidden_size": 1024, + "num_hidden_layers": 24, + "num_attention_heads": 16, + "intermediate_size": 4096, + "patch_size": 14, + "image_size": 518, + "num_channels": 3, + "num_register_tokens": 0, + }, + "dinov3-large-sat": { + "model_type": "dinov3_vit", + "hidden_size": 1024, + "num_hidden_layers": 24, + "num_attention_heads": 16, + "intermediate_size": 4096, + "patch_size": 16, + "image_size": 224, + "num_channels": 3, + "num_register_tokens": 4, + "hidden_act": "gelu", + "attention_dropout": 0.0, + "drop_path_rate": 0.0, + "initializer_range": 0.02, + "layer_norm_eps": 1e-05, + "layerscale_value": 1.0, + "key_bias": False, + "mlp_bias": True, + "proj_bias": True, + "query_bias": True, + "value_bias": True, + "use_gated_mlp": False, + "rope_theta": 100.0, + "pos_embed_rescale": 2.0, + }, + "dinov3-large-lvd": { + "model_type": "dinov3_vit", + "hidden_size": 1024, + "num_hidden_layers": 24, + "num_attention_heads": 16, + "intermediate_size": 4096, + "patch_size": 16, + "image_size": 224, + "num_channels": 3, + "num_register_tokens": 4, + "hidden_act": "gelu", + "attention_dropout": 0.0, + "drop_path_rate": 0.0, + "initializer_range": 0.02, + "layer_norm_eps": 1e-05, + "layerscale_value": 1.0, + "key_bias": False, + "mlp_bias": True, + "proj_bias": True, + "query_bias": True, + "value_bias": True, + "use_gated_mlp": False, + "rope_theta": 100.0, + "pos_embed_rescale": 2.0, + }, +} + +def load_backbone_model(backbone, skip_weights=False): + """Load (or create) the backbone vision model. 
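+
+    Weight resolution order (see ``resolve_backbone_candidates``): an
+    explicit ``SAT3DGEN_*_PATH`` environment variable is tried first,
+    then the default HuggingFace repo for *backbone*. For example,
+    with a hypothetical local path::
+
+        # SAT3DGEN_DINOV2_BASE_PATH=/models/dinov2-base python app.py
+        model = load_backbone_model("dinov2-base")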
+
+    When *skip_weights* is ``True`` the model structure is instantiated
+    from a built-in config dict **without** any network access. This is
+    useful when the caller will overwrite all parameters later (e.g. via
+    ``Sat3DGen.from_pretrained``), avoiding a redundant multi-GB
+    download of the backbone checkpoint.
+    """
+    if skip_weights:
+        if backbone not in _BACKBONE_CONFIGS:
+            raise NotImplementedError(f"No built-in config for backbone: {backbone}")
+        print(f"Creating backbone structure from built-in config (skip weights): {backbone}")
+        config = AutoConfig.for_model(**_BACKBONE_CONFIGS[backbone])
+        with no_init_weights():
+            model = AutoModel.from_config(config)
+        return model.eval().requires_grad_(False)
+
+    load_errors = []
+    for candidate in resolve_backbone_candidates(backbone):
+        expanded_candidate = os.path.expanduser(candidate)
+        resolved_candidate = expanded_candidate if Path(expanded_candidate).exists() else candidate
+        try:
+            print("Trying pretrained_model_name_or_path:", resolved_candidate)
+            return AutoModel.from_pretrained(resolved_candidate).eval().requires_grad_(False)
+        except Exception as exc:
+            load_errors.append(f"{resolved_candidate}: {exc}")
+
+    formatted_errors = "\n".join(load_errors)
+    raise RuntimeError(
+        f"Failed to load the backbone `{backbone}`.\n"
+        f"Tried the following candidates:\n{formatted_errors}\n"
+        "You can override the lookup with the corresponding SAT3DGEN_*_PATH environment variable."
+    )
+
+class MappingNetwork(torch.nn.Module):
+    def __init__(self,
+        z_dim,              # Input latent (Z) dimensionality, 0 = no latent.
+        w_dim,              # Intermediate latent (W) dimensionality.
+        num_layers = 8,     # Number of mapping layers.
+        norm = True,
+    ):
+        super().__init__()
+        self.z_dim = z_dim
+        self.w_dim = w_dim
+        self.num_layers = num_layers
+        self.norm = norm
+
+        features_list = [z_dim] * num_layers + [w_dim]
+        layers = []
+        for idx in range(num_layers):
+            layers.append(nn.Linear(features_list[idx], features_list[idx + 1]))
+            layers.append(nn.LeakyReLU(0.2))
+        self.mapping = nn.Sequential(*layers)
+
+    def forward(self, z):
+        # Optionally normalize the latent, then map it through the MLP.
+        # (The inherited StyleGAN comment about embedding/concatenating
+        # inputs does not apply here: z is the only input.)
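+        # Shape sketch: z is [B, z_dim] and the output is [B, w_dim].
+        # self.mapping is a stack of eight Linear layers, each followed by
+        # LeakyReLU(0.2): z_dim -> z_dim (7x) -> w_dim. normalize_2nd_moment
+        # rescales each sample so the mean of its squared entries is ~1,
+        # i.e. it projects z onto a sphere of radius sqrt(z_dim).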
+        if self.norm:
+            z = normalize_2nd_moment(z.to(torch.float32))  # normalize z to sphere
+
+        x = self.mapping(z)
+        return x
+
+class dino_3d_model(nn.Module):
+    def __init__(self, output_ch=192, ch_mult=[1,2,4,4,4], pad=False, with_attn=True, backbone='dinov2-base', no_hidden_states=False, no_cls_token=False, skip_backbone_weights=False):
+        super().__init__()
+        self.dino_model = load_backbone_model(backbone, skip_weights=skip_backbone_weights)
+        if backbone == 'dinov2-base':
+            z_channels = 6144 if not no_cls_token else 6144//2
+
+            self.feature_list = [3,6,9,12] if not no_hidden_states else []
+            if self.feature_list == []:
+                z_channels = z_channels//4
+        elif backbone in ['dinov2-large', "dinov3-large-sat", "dinov3-large-lvd"]:
+            z_channels = 8192 if not no_cls_token else 8192//2
+            self.feature_list = [6,12,18,24] if not no_hidden_states else []
+            if self.feature_list == []:
+                z_channels = z_channels//4
+        self.backbone = backbone
+        self.no_cls_token = no_cls_token
+        self.decoder = Decoder(ch=128, out_ch=output_ch, ch_mult=ch_mult, num_res_blocks=2, attn_resolutions=[], z_channels=z_channels, resolution=256, in_channels=None, with_attn=with_attn)
+        self.pad = pad
+        self.patch_size = self.dino_model.config.to_dict()['patch_size']
+        self.num_register_tokens = self.dino_model.config.to_dict().get('num_register_tokens', 0)
+
+    def forward(self, inputs):
+        _h, _w = inputs.shape[-2:]
+        assert _h == 16 * self.patch_size
+        output = self.dino_model(inputs, output_hidden_states=True)
+        out_put_list = []
+        if self.feature_list == []:
+            out_put_list.append(output.last_hidden_state)
+        else:
+            for i in self.feature_list:
+                out_put_list.append(output.hidden_states[i])
+            # Note: len(output.hidden_states) is num_layers + 1 (13 for
+            # dinov2-base) and feature_list already ends at the last layer
+            # (index 12); last_hidden_state is just the layer-normed
+            # output.hidden_states[-1], so appending it here would duplicate
+            # the final layer.
+        x = torch.cat(out_put_list, dim=2)
+        dino_feature = rearrange(x[:, 1+self.num_register_tokens:], 'b (h w) c -> b c h w', h=_h//self.patch_size, w=_w//self.patch_size)
+        if not self.no_cls_token:
+            cls_token = x[:, 0]
+            dino_feature = torch.cat([dino_feature, cls_token.unsqueeze(-1).unsqueeze(-1).repeat(1, 1, _h//self.patch_size, _w//self.patch_size)], dim=1)  # e.g. [2, 7680, 16, 16]
+        if self.pad:
+            ori_size = dino_feature.size(-1)
+            pad_size = ori_size*self.pad
+            # make sure the pad size is an integer
+            assert pad_size == int(pad_size), 'pad_size should be int'
+            pad_size = int(pad_size)
+            dino_feature = F.pad(dino_feature, (pad_size, pad_size, pad_size, pad_size), 'constant', 0)
+        output = self.decoder(dino_feature)  # 320*320: four 2x upsampling stages
+        return output
+
+
+def convert_to_easydict(d):
+    if isinstance(d, dict):
+        return edict({k: convert_to_easydict(v) for k, v in d.items()})
+    return d
+
+class Sat3DGen(ModelMixin, ConfigMixin):
+    # When True, skip downloading pretrained backbone weights during __init__.
+    # The weights will be loaded later by from_pretrained() from the full
+    # checkpoint (which already contains the backbone parameters), avoiding
+    # a redundant multi-GB download of the standalone backbone model.
+    _skip_backbone_weights: bool = False
+
+    @register_to_config
+    def __init__(self, opt):
+        super().__init__()
+        self.opt = opt
+
+        # If opt is a plain dict, convert it recursively to an EasyDict.
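+        # e.g. convert_to_easydict({'network': {'latent_size': 16}}) yields an
+        # EasyDict whose nested keys support attribute access
+        # (opt.network.latent_size == 16), as used throughout this class.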
+        self.opt = convert_to_easydict(opt)
+        if 'sr_padding_mode' not in self.opt.keys():
+            self.opt.sr_padding_mode = 'zeros'
+        if 'representation_type' not in self.opt.keys():
+            self.opt.representation_type = 'triplane'
+        self.sat_mapping_mode = 'v2' if not hasattr(self.opt.network, 'sat_mapping_mode') else self.opt.network.sat_mapping_mode
+        assert self.sat_mapping_mode in ['v2'], 'sat_mapping_mode should be v2 (v1 is no longer supported)'
+        self.sr_factor = 1 if not hasattr(self.opt.network, 'sr_factor') else self.opt.network.sr_factor
+        self.if_w_sky_mapping = True
+        self.backbone = 'dinov2-base' if not hasattr(self.opt, 'backbone') else self.opt.backbone
+        if self.if_w_sky_mapping:
+            self.z_dim = 270
+            self.w_dim = 512
+            self.sky_mapping = MappingNetwork(self.z_dim, self.w_dim, norm=False)
+        else:
+            self.z_dim = 270
+            self.w_dim = 270
+        assert self.sr_factor in [1, 2], 'sr_factor should be 1 or 2'
+        self.image_size = self.opt.network.image_size  # not used
+        self.latent_size = self.opt.network.latent_size
+        self.latent_channel = self.opt.network.latent_channel
+        if 'pad' in self.opt.keys():
+            self.pad = self.opt.pad
+            self.position_scale_factor = 1 / (self.pad * 2 + 1)
+            assert self.opt.network.position_scale_factor == 1, 'position_scale_factor should be 1; it is not used in this version.'
+        else:
+            self.position_scale_factor = self.opt.network.position_scale_factor
+            self.pad = False
+
+        color_channels = 32 if not hasattr(self.opt.network, 'color_channels') else self.opt.network.color_channels
+        self.sr_module = SuperresolutionHybrid2X(color_channels, 3, padding_mode=self.opt.sr_padding_mode, v2=True)
+        if self.opt.representation_type == 'triplane':
+            output_ch = self.opt.network.triplane.dim * 3
+        elif self.opt.representation_type in ['oneplane', 'oneplane_multi']:
+            output_ch = self.opt.network.triplane.dim * 2
+        self.with_sky = True
+        self.sky_input_dim = 2
+
+        if self.with_sky:
+            self.sky_decoder = Decoder(ch=32, out_ch=color_channels, ch_mult=[1,2,2,4,4,4,4], num_res_blocks=2, attn_resolutions=[], z_channels=self.w_dim, resolution=256, in_channels=None, with_attn=False, pano_pad=True)
+        self.unet_model = dino_3d_model(output_ch=output_ch,
+                                        ch_mult=self.opt.network.triplane.ch_mult if hasattr(self.opt.network.triplane, 'ch_mult') else [1,2,4,4,4],
+                                        pad=self.pad,
+                                        with_attn=self.opt.network.with_attn if hasattr(self.opt.network, 'with_attn') else True,
+                                        backbone=self.backbone,
+                                        no_hidden_states=self.opt.network.no_hidden_states if hasattr(self.opt.network, 'no_hidden_states') else False,
+                                        no_cls_token=self.opt.network.no_cls_token if hasattr(self.opt.network, 'no_cls_token') else False,
+                                        skip_backbone_weights=self._skip_backbone_weights,
+                                        )
+
+        self.num_importance = self.opt.network.point_sampling_kwargs.num_importance
+        # Remove num_importance from self.opt.network.point_sampling_kwargs.
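+        # num_importance drives the fine (importance) pass inside
+        # from_point_sampling2result, so it is read here and stripped before
+        # point_sampling_kwargs is unpacked verbatim into the point samplers.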
+ self.opt.network.point_sampling_kwargs.pop('num_importance') + if self.opt.representation_type == 'oneplane': + input_dim_mlp = self.opt.network.triplane.dim*2 + elif self.opt.representation_type in ['triplane','oneplane_multi']: + input_dim_mlp = self.opt.network.triplane.dim + self.mlp = MLPNetwork2(input_dim=input_dim_mlp, + hidden_dim=64, + output_dim=color_channels, + style_dim=self.w_dim, + ) + + self.point_representer = PointRepresenter( + representation_type=self.opt.representation_type, + triplane_axes=None, + coordinate_scale=None, + ) + self.point_integrator = PointIntegrator(**self.opt.network.ray_marching_kwargs) + unused_parameter = ['max_height','origin_height','realworld_scale'] + for i in unused_parameter: + if i in self.opt.network.point_sampling_kwargs.keys(): + self.opt.network.point_sampling_kwargs.pop(i) + if self.sr_factor ==2: + self.opt.render_size = 256 + self.point_sampler_definition(self.opt.render_size if hasattr(self.opt, 'render_size') else 256) + + def point_sampler_definition(self, render_size=256): + pano_size = np.array([render_size*2,render_size//2]) / self.sr_factor + self.pano_direction = get_original_coord(W=int(pano_size[0]),H=int(pano_size[1]),full=True).unsqueeze(0).cuda().float() # different with the original one + # point_sampling_kwargs.pano_direction = + self.point_sampler = Point_sampler_pano(pano_direction=self.pano_direction,**self.opt.network.point_sampling_kwargs) + self.point_sampler_per = PointSamplerPerspective(num_points=self.opt.network.point_sampling_kwargs.num_points,aabb_strict=True,render_size=[render_size// self.sr_factor,render_size// self.sr_factor]) + if render_size==256 and self.sr_factor == 2: + self.point_sampler_sat = Point_sampler_ortho(num_points=self.opt.network.point_sampling_kwargs.num_points,position_scale_factor=self.position_scale_factor,render_size=render_size// self.sr_factor) + else: + self.point_sampler_sat = Point_sampler_ortho(num_points=self.opt.network.point_sampling_kwargs.num_points,position_scale_factor=self.position_scale_factor,resolution=int(render_size*1.5),render_size=render_size) + print('render size:', render_size, 'sr_factor:', self.sr_factor) + + + def from_sat_to_triplane(self,x): + planes_feature = self.unet_model(x) + if self.opt.representation_type == 'triplane': + triplane_ori = rearrange(planes_feature, 'b (n c) h w -> b n c h w',n=3) + elif self.opt.representation_type in ['oneplane','oneplane_multi']: + one_plane_ori = planes_feature[:,:self.opt.network.triplane.dim] + one_plane_ori = rearrange(one_plane_ori, 'b (n c) h w -> b n c h w',n=1) + one_line_ori = planes_feature[:,self.opt.network.triplane.dim:] + one_line_ori = torch.mean(one_line_ori, dim=2, keepdim=False) + triplane_ori = [one_plane_ori,one_line_ori] + return triplane_ori + + def c2w_prepare(self, c2w): + if c2w is not None: + c2w[:,:3, 3] = c2w[:,:3, 3] * self.position_scale_factor + return c2w + + def w_sky_prepare(self, z_ill): + if z_ill is not None: + if self.if_w_sky_mapping: + w_sky = self.sky_mapping(z_ill) + else: + w_sky = z_ill + else: + w_sky = None + return w_sky + + def w_sky2sky_feature_2D(self, w_sky, z_ill=None): + sky_feature_2D = None + if self.with_sky and z_ill is not None: + sky_feature_2D = repeat(w_sky, 'b c -> b c h w', h=8, w=8) + sky_feature_2D = self.sky_decoder(sky_feature_2D) + sky_feature_2D = torch.sigmoid(sky_feature_2D) + # pad to full panorama width + b,c,h,w = sky_feature_2D.shape + zero_pad_sky = torch.zeros((b,c,h,int(w*0.8)),device=sky_feature_2D.device) + sky_feature_2D = 
torch.cat([sky_feature_2D,zero_pad_sky],dim=3) + return sky_feature_2D + + def from_3D_to_results(self, + triplane_ori, + c2w=None, + w_sky=None, + sky_feature_2D=None, + syn_sat=False, + random_sat_crop=True, + syn_pano=True, + syn_per=False, + same_histo=False, + intrinsics=None, + coordinates=None): + results = edict() + point_sampling_result = [] + w_list = [] + syn_sign = [] + if type(triplane_ori) is list: + N = triplane_ori[0].shape[0] + else: + N = triplane_ori.shape[0] + # triplane_ori_repeat = triplane_ori.repeat(2,1,1,1,1) + + + if syn_sat: + point_sampling_result_sat = self.point_sampler_sat(batch_size=N,random_crop=random_sat_crop,crop_type='crop') + point_sampling_result.append(point_sampling_result_sat) + if not same_histo: + w_sat = torch.zeros([N,self.w_dim], device=triplane_ori.device if type(triplane_ori) is not list else triplane_ori[0].device) + else: + w_sat = w_sky + + w_list.append(w_sat) + syn_sign.append('sat') + + if syn_pano: + resize_for_pano = False + point_sampling_result_pano = self.point_sampler(batch_size=N,position=c2w[:,:3, 3]) + if self.training: + if point_sampling_result_pano.rays_world.size(1) != point_sampling_result_pano.rays_world.size(2): + resize_for_pano = True + # rearrange from [4, 64, 256, 3] to [4, 64*2, 128/2, 3] + point_sampling_result_pano.rays_world = rearrange(point_sampling_result_pano.rays_world, 'b h (w d) c -> b (h d) w c', d=2) + point_sampling_result_pano.ray_origins = rearrange(point_sampling_result_pano.ray_origins, 'b h (w d) c -> b (h d) w c', d=2) + point_sampling_result_pano.points_world = rearrange(point_sampling_result_pano.points_world, 'b h (w d) n c -> b (h d) w n c', d=2) + point_sampling_result_pano.radii = rearrange(point_sampling_result_pano.radii, 'b h (w d) c -> b (h d) w c', d=2) + point_sampling_result.append(point_sampling_result_pano) + w_list.append(w_sky) + syn_sign.append('pano') + if syn_per: + point_sampling_result_per = self.point_sampler_per(intrinsics=intrinsics, c2w=c2w) + point_sampling_result.append(point_sampling_result_per) + w_list.append(w_sky) + syn_sign.append('pespective') + + if self.training and len(point_sampling_result) >1: + point_sampling_result_cat = edict() + point_sampling_result_cat.rays_world = torch.cat([i.rays_world for i in point_sampling_result],dim=0) + point_sampling_result_cat.ray_origins = torch.cat([i.ray_origins for i in point_sampling_result],dim=0) + point_sampling_result_cat.points_world = torch.cat([i.points_world for i in point_sampling_result],dim=0) + point_sampling_result_cat.radii = torch.cat([i.radii for i in point_sampling_result],dim=0) + + w_input = torch.cat(w_list,dim=0) + if self.opt.representation_type == 'triplane': + feature_input = triplane_ori.repeat(len(point_sampling_result),1,1,1,1) + + elif self.opt.representation_type in ['oneplane','oneplane_multi']: + feature_input = [triplane_ori[0].repeat(len(point_sampling_result),1,1,1,1),triplane_ori[1].repeat(len(point_sampling_result),1,1)] + output = self.from_point_sampling2result(point_sampling_result_cat, + feature_input, + w_sky=w_input, + ) + else: + for i in range(len(point_sampling_result)): + if syn_sign[i] == 'sat': + results.sat_output = self.from_point_sampling2result(point_sampling_result[i], + triplane_ori, + w_sky=w_list[i], + ) + elif syn_sign[i] == 'pano': + results.str_output = self.from_point_sampling2result(point_sampling_result[i], + triplane_ori, + w_sky=w_list[i], + ) + elif syn_sign[i] == 'pespective': + results.per_output = 
self.from_point_sampling2result(point_sampling_result[i], + triplane_ori, + w_sky=w_list[i], + ) + if 'sat' in syn_sign: + if self.training and len(point_sampling_result) >1: + results.sat_output = edict() + results.sat_output.feature_raw = output.feature_raw[:N] + results.sat_output.alpha_raw = output.alpha_raw[:N] + results.sat_output.image_depth = output.image_depth[:N] + results.sat_output.image_radii = output.image_radii[:N] + if 'idx' in point_sampling_result_sat.keys(): + results.sat_output.idx = point_sampling_result_sat['idx'] + + if 'pano' in syn_sign: + if self.training and len(point_sampling_result) >1: + results.str_output = edict() + results.str_output.feature_raw = output.feature_raw[N:2*N] + results.str_output.alpha_raw = output.alpha_raw[N:2*N] + results.str_output.image_depth = output.image_depth[N:2*N] + results.str_output.image_radii = output.image_radii[N:2*N] + + if resize_for_pano: + results.str_output.feature_raw = rearrange(results.str_output.feature_raw, 'b c (h d) w -> b c h (w d)', d=2) + results.str_output.alpha_raw = rearrange(results.str_output.alpha_raw, 'b c (h d) w -> b c h (w d)', d=2) + results.str_output.image_depth = rearrange(results.str_output.image_depth, 'b c (h d) w -> b c h (w d)', d=2) + results.str_output.image_radii = rearrange(results.str_output.image_radii, 'b c (h d) w -> b c h (w d)', d=2) + if 'idx' in point_sampling_result_pano.keys(): + results.str_output.idx = point_sampling_result_pano['idx'] + + results.str_output.ray_direction = point_sampling_result_pano.rays_world + # render sky + if self.with_sky: + ray_direction = xyz2thetaphi(results.str_output.ray_direction) + sky_img = F.grid_sample(sky_feature_2D, ray_direction,align_corners=True) + sky_img = torch.clamp(sky_img, 0, 1) + if resize_for_pano: + sky_img = rearrange(sky_img, 'b c (h d) w -> b c h (w d)', d=2) + rgb_feature_compo = results.str_output.feature_raw * results.str_output.alpha_raw + sky_img * (1 - results.str_output.alpha_raw) + results.str_output.sky_img = sky_img + results.str_output.image_raw_compo = rgb_feature_compo + if self.sr_factor == 2: + results.str_output.sr_image = self.sr_module(rgb_feature_compo) + + if 'pespective' in syn_sign: + if self.training and len(point_sampling_result) >1: + results.per_output = edict() + results.per_output.feature_raw = output.feature_raw[-N:] + results.per_output.alpha_raw = output.alpha_raw[-N:] + results.per_output.image_depth = output.image_depth[-N:] + results.per_output.image_radii = output.image_radii[-N:] + + + if 'idx' in point_sampling_result_per.keys(): + results.per_output.idx = point_sampling_result_per['idx'] + + results.per_output.ray_direction = point_sampling_result_per.rays_world + + # render sky + if self.with_sky: + ray_direction = xyz2thetaphi(results.per_output.ray_direction) + sky_img = F.grid_sample(sky_feature_2D, ray_direction,align_corners=True) + sky_img = torch.clamp(sky_img, 0, 1) + rgb_feature_compo = results.per_output.feature_raw * results.per_output.alpha_raw + sky_img * (1 - results.per_output.alpha_raw) + results.per_output.sky_img = sky_img + results.per_output.image_raw_compo = rgb_feature_compo + if self.sr_factor == 2: + results.per_output.sr_image = self.sr_module(rgb_feature_compo) + + if coordinates is not None: + # for density regularization + results.density = self.density_reg(coordinates,triplane_ori) + return results + + def density_reg(self,coordinates,triplane_ori,sample_color=False,w_sky=None): + # Only for density regularization in training process. 
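+        # The raw MLP output is mapped to a non-negative density before the
+        # regulariser sees it; a sketch of the two clamp modes handled below
+        # (values illustrative):
+        #   'mipnerf': sigma = softplus(raw - 1)   # ~0.31 at raw = 0
+        #   'relu':    sigma = relu(raw + 3)       # 3.0  at raw = 0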
+ assert coordinates is not None + sample_result = self.sample_mixed(coordinates, + triplane_ori, + sample_color=sample_color, + w_sky=w_sky, + ) + sample_density = sample_result['density'] + color_result = sample_result['color'][...,:3] if sample_color==True else None + if self.opt.network.ray_marching_kwargs.density_clamp_mode == 'mipnerf': + sample_density = F.softplus(sample_density - 1) + elif self.opt.network.ray_marching_kwargs.density_clamp_mode == 'relu': + sample_density = F.relu(sample_density + 3) + else: + raise NotImplementedError + if sample_color: + return color_result + return sample_density + + + def forward(self, + x, + z_ill=None, + syn_sat=False, + random_sat_crop=True, + syn_pano=True, + syn_per=False, + same_histo=False, + intrinsics=None, + c2w=None, + coordinates=None, + ): + c2w = self.c2w_prepare(c2w) + + triplane_ori = self.from_sat_to_triplane(x) + + w_sky = self.w_sky_prepare(z_ill) + + sky_feature_2D = self.w_sky2sky_feature_2D(w_sky,z_ill) + + + results = self.from_3D_to_results(triplane_ori, + c2w, + w_sky, + sky_feature_2D, + syn_sat=syn_sat, + random_sat_crop=random_sat_crop, + syn_pano=syn_pano, + syn_per=syn_per, + same_histo=same_histo, + intrinsics=intrinsics, + coordinates=coordinates) + results.triplane = triplane_ori + + return results + + + + + + + def sample_mixed(self, + coordinates, + triplanes, + sample_color=False, + w_sky=None + ): + + + + point_features = self.point_representer( + coordinates, ref_representation=triplanes) + color_density_result = self.mlp(point_features,only_density=not sample_color,style=w_sky) # point_features: B N C + + return color_density_result + + + + def from_point_sampling2result(self, + point_sampling_result, + triplanes, + w_sky=None, + **synthesis_kwarg + ): + points = point_sampling_result['points_world'] # [N, H, W, K, 3] + ray_dirs = point_sampling_result['rays_world'] # [N, H, W, 3] + radii_coarse = point_sampling_result['radii'] # [N, H, W, K] + ray_origins = point_sampling_result['ray_origins'] # [N, 3] + + _, H, W, K, _ = points.shape + R = H * W + points_coarse = rearrange(points, 'n h w k c -> n (h w) k c') # [N, R, K, 3] + points = rearrange(points, 'n h w k c -> n (h w k) c') # [N, R * K, 3] + ray_dirs = rearrange(ray_dirs, 'n h w c -> n (h w) c') + if len(ray_origins.shape) == 4: + ray_origins = rearrange(ray_origins, 'n h w c -> n (h w) c') + elif len(ray_origins.shape) == 2: + ray_origins = repeat(ray_origins, 'n c -> n (h w) c', h=R, w=1) + radii_coarse = rearrange(radii_coarse, 'n h w k -> n (h w) k 1') + + point_features = self.point_representer( + points, ref_representation=triplanes) # [N, R * K, C] + color_density_result = self.mlp(point_features,w_sky) # point_features: B N C + + densities_coarse = color_density_result['density'] # [N, R * K, 1] + colors_coarse = color_density_result['color'] # [N, R * K, C] + densities_coarse = rearrange(densities_coarse, 'n (r k) c -> n r k c', r=R, k=K) + colors_coarse = rearrange(colors_coarse, 'n (r k) c -> n r k c', r=R, k=K) + + if self.num_importance > 0: + # Do the integration along the coarse pass. + rendering_result = self.point_integrator(colors_coarse, + densities_coarse, + radii_coarse) + weights = rendering_result['weight'] + + # Importance sampling. 
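+            # Hierarchical sampling as in NeRF: the coarse weights act as an
+            # unnormalised PDF along each ray, and fine sample distances are
+            # drawn from its inverse CDF so that they cluster near surfaces.
+            # Single-ray sketch of the idea (sample_importance additionally
+            # smooths the weights first):
+            #   pdf = w / w.sum(); cdf = pdf.cumsum(-1)
+            #   t_fine = interp(uniform_samples, cdf, t_coarse)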
+ radii_fine = sample_importance(radii_coarse, + weights, + self.num_importance, + smooth_weights=True) + points = ray_origins.unsqueeze( + -2) + radii_fine * ray_dirs.unsqueeze( + -2) # [N, R, num_importance, 3] + points_fine = points + points = rearrange(points, 'n r k c -> n (r k) c') # [N, R * num_importance, 3] + + point_features = self.point_representer( + points, ref_representation=triplanes) + color_density_result = self.mlp(point_features,w_sky) + + densities_fine = color_density_result['density'] + colors_fine = color_density_result['color'] + densities_fine = rearrange(densities_fine, 'n (r k) c -> n r k c', r=R, k=self.num_importance) + colors_fine = rearrange(colors_fine, 'n (r k) c -> n r k c', r=R, k=self.num_importance) + + # Gather coarse and fine results together. + (all_radiis, all_colors, all_densities, + all_points) = unify_attributes(radii_coarse, + colors_coarse, + densities_coarse, + radii_fine, + colors_fine, + densities_fine, + points1=points_coarse, + points2=points_fine) + + # Do the integration along the fine pass. + rendering_result = self.point_integrator(all_colors, + all_densities, + all_radiis) + + else: + # Only do the integration along the coarse pass. + rendering_result = self.point_integrator(colors_coarse, + densities_coarse, + radii_coarse) + # all_points = points_coarse # [N, R, K, 3] + + feature_samples = rendering_result['composite_color'] + radii_samples = rendering_result['composite_radial_dist'] + + feature_image = rearrange(feature_samples, 'n (h w) c -> n c h w', h=H, w=W).contiguous() # [N, C, H, W] + image_radii = rearrange(radii_samples, 'n (h w) c -> n c h w', h=H, w=W).contiguous() # [N, 1, H, W] + + image_alpha = rearrange(rendering_result['opacity'], 'n (h w) c -> n c h w', h=H, w=W).contiguous() + image_depth = rearrange(rendering_result['composite_radial_dist'], 'n (h w) c -> n c h w', h=H, w=W).contiguous() + + # rgb_image = feature_image[:, :3] + result = edict() + result.feature_raw = feature_image + result.alpha_raw = image_alpha + # result.image_raw = rgb_image + result.image_depth = image_depth + result.image_radii = image_radii + result.ray_origin = ray_origins + if 'idx' in point_sampling_result.keys(): + result.idx = point_sampling_result['idx'] + return result + + @torch.no_grad() + def forward_grid(self, planes, grid_size=256,position_scale_factor=1,crop=False): + max_batch = 15000000 + # size = int(grid_size/self.position_scale_factor) + voxel_grid = create_voxel(N=grid_size,position_scale_factor=1)['voxel_grid'].cuda() + densities = torch.zeros( + (voxel_grid.shape[0], voxel_grid.shape[1], 1)).cuda() + # data/CVACT/satview_correct/ + + # read img to cuda, [-1,1] + + head = 0 + with tqdm.tqdm(total=voxel_grid.shape[1]) as pbar: + with torch.no_grad(): + while head < voxel_grid.shape[1]: + density = self.density_reg(coordinates=voxel_grid[:, head:head + max_batch],triplane_ori=planes) + # density = self.forward(sat_img, + # z, + # None, + # syn_pano=False, + # coordinates=voxel_grid[:, head:head + max_batch])['density'] + # if self.opt.network.ray_marching_kwargs.density_clamp_mode == 'mipnerf': + # densities = F.softplus(densities - 1) + # else: + # raise NotImplementedError + + # density = G.sample( + # voxel_grid[:, head:head + max_batch], + # batch_codes, + # sat_img, + # **G_kwargs)['density'] + densities[:, head:head + max_batch] = density + head = head + max_batch + pbar.update(max_batch) + + densities = densities.reshape( + (grid_size, grid_size, grid_size)).cpu().numpy() + # densities = np.flip(densities, 0) + # 
densities = np.flip(densities, 0) + # densities = np.flip(densities, 1) + # # Trim the border of the extracted cube. + if self.position_scale_factor < 1: + pad = int(np.round(((1-self.position_scale_factor)*densities.shape[0]/2))) + if not crop: + # densities = densities[pad:-pad, pad:-pad, pad:] + # return densities + # else: + pad_value = 0 + densities[:pad] = pad_value + densities[-pad:] = pad_value + densities[:, :pad] = pad_value + densities[:, -pad:] = pad_value + densities[:, :, :pad] = pad_value # z space + else: + densities = densities[pad:-pad, pad:-pad, pad:] + return densities + + + @torch.no_grad() + def save_shape_from_sat(self, sat_img, position_scale_factor=1,crop=False,grid_size=320): + planes = self.from_sat_to_triplane(sat_img) + + return self.forward_grid(planes,position_scale_factor=1,crop=crop,grid_size=grid_size) + + @torch.no_grad() + def save_shape(self, planes,position_scale_factor=1,save_type='density',crop=False): + densities = self.forward_grid(planes,position_scale_factor=position_scale_factor) + + if save_type == 'density': + try: + import mrcfile + except ImportError: + raise ImportError("mrcfile is required for density export. Install via: pip install mrcfile") + with mrcfile.new_mmap(f'0000.mrc', + overwrite=True, + shape=densities.shape, + mrc_mode=2) as mrc: + mrc.data[:] = densities + print('save density done') + + try: + import open3d as o3d + except ImportError: + raise ImportError("open3d is required for 3D shape export. Install via: pip install open3d") + if save_type == 'mesh': + from skimage import measure + import trimesh + # Extract a mesh with Marching Cubes. + verts, faces, _, _ = measure.marching_cubes(densities, level=4.5) + + + # Build the Trimesh object. + mesh = trimesh.Trimesh(vertices=verts, faces=faces) + + # Compute vertex normals. + mesh.vertex_normals + + # Optional mesh visualization. + # mesh.show() + + # Export the mesh as a PLY file. + mesh.export('mesh.ply') + + + + + if save_type in ['pointcloud','voxel']: + + + + def efficient_filter_numpy(densities, threshold=5): + size = densities.shape[0] + + # Mark voxels whose density is above the threshold. + high_density = np.where(densities >= threshold, 1, 0) + + # Count high-density voxels along each local axis. + x_sum = high_density[:-2, 1:-1, 1:-1] +high_density[1:-1, 1:-1, 1:-1] + high_density[2:, 1:-1, 1:-1] + y_sum = high_density[1:-1, :-2, 1:-1] + high_density[1:-1, 1:-1, 1:-1] + high_density[1:-1, 2:, 1:-1] + z_sum = high_density[1:-1, 1:-1, :-2] + high_density[1:-1, 1:-1, 1:-1] + high_density[1:-1, 1:-1, 2:] + # Keep only voxels that satisfy all local support conditions. + mask = (x_sum == 3) & (y_sum == 3) & (z_sum == 3) + + # Remove voxels that pass the mask. + densities[1:-1, 1:-1, 1:-1][mask] = 0 + + return densities + + # print the number of voxels >= 5 + print('the number of voxels >= 5 before filtering:', np.sum(densities >= 5)) + densities = efficient_filter_numpy(densities) + # print the number of voxels >= 5 after filtering + print('the number of voxels >= 5 after filtering:', np.sum(densities >= 5)) + + + + + points = np.array(np.where(densities >= 5)).T + points = (points / size) *2 - 1 + + point_cloud = o3d.geometry.PointCloud() + point_cloud.points = o3d.utility.Vector3dVector(points) + + + + # def position_to_color(points): + # # Map point coordinates from [-1, 1] to [0, 1]. + # normalized_points = (points + 1) / 3 + + # # Use x, y, z as r, g, b. + # colors = normalized_points + + # # A more complex color mapping is also possible. 
+ # # colors = np.column_stack([ + # # normalized_points[:, 0], # r from x + # # (normalized_points[:, 1] + normalized_points[:, 2]) / 2, # g from (y+z)/2 + # # 1 - normalized_points[:, 2] # b from 1-z + # # ]) + + # return colors + # colors = position_to_color(np.asarray(point_cloud.points)) + # point_cloud.colors = o3d.utility.Vector3dVector(colors) + + + if save_type == 'pointcloud': + # save point cloud + o3d.io.write_point_cloud("point_cloud.ply", point_cloud) + + if save_type == 'voxel': + voxel_size = (1 / size)* 2 + voxel_grid = o3d.geometry.VoxelGrid.create_from_point_cloud(point_cloud, voxel_size) + o3d.io.write_voxel_grid("voxel_grid.ply", voxel_grid) + + # elif save_type == 'mesh': + # mesh = voxel_grid.to_mesh() + # o3d.io.write_triangle_mesh("mesh.ply", mesh) + + # print(xyz.shape) + + return 0 + + + def extract_mesh( + self, + planes: torch.Tensor, + mesh_resolution: int = 320, + mesh_threshold: int = 5.0, + w_sky = None, + **kwargs, + ): + ''' + Extract a 3D mesh from triplane nerf. Only support batch_size 1. + :param planes: triplane features + :param mesh_resolution: marching cubes resolution + :param mesh_threshold: iso-surface threshold + ''' + print('mesh_resolution:', mesh_resolution) + device = planes.device if type(planes) is not list else planes[0].device + + grid_out = self.forward_grid( + planes=planes, + grid_size=mesh_resolution, + ) + try: + import mcubes + except ImportError: + raise ImportError("PyMCubes is required for mesh extraction. Install via: pip install PyMCubes") + vertices, faces = mcubes.marching_cubes( + grid_out, + mesh_threshold, + ) + vertices = vertices / (mesh_resolution - 1) * 2 - 1 + # query vertex colors + vertices_tensor = torch.tensor(vertices, dtype=torch.float32, device=device).unsqueeze(0) + vertices_colors = self.density_reg(vertices_tensor,planes,sample_color=True,w_sky=w_sky) + vertices_colors = (vertices_colors * 255).squeeze(0).cpu().numpy().astype(np.uint8) + return vertices, faces, vertices_colors + + + +class EMANorm(nn.Module): + def __init__(self, beta): + super().__init__() + self.register_buffer('magnitude_ema', torch.ones([])) + self.beta = beta + + def forward(self, x): + if self.training: + magnitude_cur = x.detach().to(torch.float32).square().mean() + self.magnitude_ema.copy_(magnitude_cur.lerp(self.magnitude_ema, self.beta)) + input_gain = self.magnitude_ema.rsqrt() + x = x.mul(input_gain) + return x + + +# Backward-compatible alias so that existing config.json files with +# "_class_name": "VAE_finetune" (e.g. on HuggingFace) keep working. 
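+# Loaders that resolve the class by name from this module (e.g. via
+# getattr(module, config["_class_name"])) therefore find both "Sat3DGen"
+# and the legacy "VAE_finetune" name here.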
+VAE_finetune = Sat3DGen + + diff --git a/source/losses/__pycache__/density_tv_loss.cpython-310.pyc b/source/losses/__pycache__/density_tv_loss.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..fd516d7f9bb8e49494e7bfdffa40a4bb9db1e757 Binary files /dev/null and b/source/losses/__pycache__/density_tv_loss.cpython-310.pyc differ diff --git a/source/losses/__pycache__/depth_loss.cpython-310.pyc b/source/losses/__pycache__/depth_loss.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..19204644efea4e4945ae820f588b5ea4858b693b Binary files /dev/null and b/source/losses/__pycache__/depth_loss.cpython-310.pyc differ diff --git a/source/losses/__pycache__/sat_depth_loss.cpython-310.pyc b/source/losses/__pycache__/sat_depth_loss.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7ba988e4ec4584699a8fcfcb94a82a2b183e9556 Binary files /dev/null and b/source/losses/__pycache__/sat_depth_loss.cpython-310.pyc differ diff --git a/source/losses/density_tv_loss.py b/source/losses/density_tv_loss.py new file mode 100644 index 0000000000000000000000000000000000000000..4b58e7b29007dca46b0d2c267bd0b9b82799c714 --- /dev/null +++ b/source/losses/density_tv_loss.py @@ -0,0 +1,22 @@ +import torch + + +def compute_density_tv_loss(density, alpha): + """Compute total variation loss on density to regularise geometry. + + The density tensor is expected to contain initial and perturbed samples + concatenated along dim-1 (i.e. ``density.shape[1] == 2 * N``). The loss + encourages the perturbed density not to exceed the initial density by more + than *alpha*, which acts as a soft margin. + + Args: + density: Tensor of shape ``(B, 2*N, ...)``. + alpha: Scalar margin added to the initial density. + + Returns: + Scalar loss value. + """ + density_len = density.shape[1] // 2 + density_initial = density[:, :density_len] + density_perturbed = density[:, density_len:] + return torch.mean(torch.relu(density_perturbed - (density_initial + alpha))) diff --git a/source/losses/depth_loss.py b/source/losses/depth_loss.py new file mode 100644 index 0000000000000000000000000000000000000000..85967dd52bd963f863c9df5bf2054cea520aa523 --- /dev/null +++ b/source/losses/depth_loss.py @@ -0,0 +1,164 @@ +# get from https://github.com/EPFL-VILAB/omnidata/blob/1af855042a05778d029d420b2a4bc1b9a0c09f30/omnidata_tools/torch/losses/midas_loss.py#L10 +# only change line 154-155, for our gt is 1/depth, not depth +# Based on https://gist.github.com/dvdhfnr/732c26b61a0e63a0abc8a5d769dbebd0 + +import torch +import torch.nn as nn +import numpy as np + + + +def masked_l1_loss(preds, target, mask_valid): + element_wise_loss = abs(preds - target) + element_wise_loss[~mask_valid] = 0 + return element_wise_loss.sum() / mask_valid.sum() + +def compute_scale_and_shift(prediction, target, mask): + # system matrix: A = [[a_00, a_01], [a_10, a_11]] + a_00 = torch.sum(mask * prediction * prediction, (1, 2)) + a_01 = torch.sum(mask * prediction, (1, 2)) + a_11 = torch.sum(mask, (1, 2)) + + # right hand side: b = [b_0, b_1] + b_0 = torch.sum(mask * prediction * target, (1, 2)) + b_1 = torch.sum(mask * target, (1, 2)) + + # solution: x = A^-1 . b = [[a_11, -a_01], [-a_10, a_00]] / (a_00 * a_11 - a_01 * a_10) . 
b + x_0 = torch.zeros_like(b_0) + x_1 = torch.zeros_like(b_1) + + det = a_00 * a_11 - a_01 * a_01 + valid = det.nonzero() + + x_0[valid] = (a_11[valid] * b_0[valid] - a_01[valid] * b_1[valid]) / (det[valid] + 1e-6) + x_1[valid] = (-a_01[valid] * b_0[valid] + a_00[valid] * b_1[valid]) / (det[valid] + 1e-6) + + return x_0, x_1 + + +def masked_shift_and_scale(depth_preds, depth_gt, mask_valid): + depth_preds_nan = depth_preds.clone() + depth_gt_nan = depth_gt.clone() + depth_preds_nan[~mask_valid] = np.nan + depth_gt_nan[~mask_valid] = np.nan + + mask_diff = mask_valid.view(mask_valid.size()[:2] + (-1,)).sum(-1, keepdims=True) + 1 + + t_gt = depth_gt_nan.view(depth_gt_nan.size()[:2] + (-1,)).nanmedian(-1, keepdims=True)[0].unsqueeze(-1) + t_gt[torch.isnan(t_gt)] = 0 + diff_gt = torch.abs(depth_gt - t_gt) + diff_gt[~mask_valid] = 0 + s_gt = (diff_gt.view(diff_gt.size()[:2] + (-1,)).sum(-1, keepdims=True) / mask_diff).unsqueeze(-1) + depth_gt_aligned = (depth_gt - t_gt) / (s_gt + 1e-6) + + + t_pred = depth_preds_nan.view(depth_preds_nan.size()[:2] + (-1,)).nanmedian(-1, keepdims=True)[0].unsqueeze(-1) + t_pred[torch.isnan(t_pred)] = 0 + diff_pred = torch.abs(depth_preds - t_pred) + diff_pred[~mask_valid] = 0 + s_pred = (diff_pred.view(diff_pred.size()[:2] + (-1,)).sum(-1, keepdims=True) / mask_diff).unsqueeze(-1) + depth_pred_aligned = (depth_preds - t_pred) / (s_pred + 1e-6) + + return depth_pred_aligned, depth_gt_aligned + + +def reduction_batch_based(image_loss, M): + # average of all valid pixels of the batch + + # avoid division by 0 (if sum(M) = sum(sum(mask)) = 0: sum(image_loss) = 0) + divisor = torch.sum(M) + + if divisor == 0: + return 0 + else: + return torch.sum(image_loss) / divisor + + +def reduction_image_based(image_loss, M): + # mean of average of valid pixels of an image + + # avoid division by 0 (if M = sum(mask) = 0: image_loss = 0) + valid = M.nonzero() + + image_loss[valid] = image_loss[valid] / M[valid] + + return torch.mean(image_loss) + + + +def gradient_loss(prediction, target, mask, reduction=reduction_batch_based): + + M = torch.sum(mask, (1, 2)) + + diff = prediction - target + diff = torch.mul(mask, diff) + + grad_x = torch.abs(diff[:, :, 1:] - diff[:, :, :-1]) + mask_x = torch.mul(mask[:, :, 1:], mask[:, :, :-1]) + grad_x = torch.mul(mask_x, grad_x) + + grad_y = torch.abs(diff[:, 1:, :] - diff[:, :-1, :]) + mask_y = torch.mul(mask[:, 1:, :], mask[:, :-1, :]) + grad_y = torch.mul(mask_y, grad_y) + + image_loss = torch.sum(grad_x, (1, 2)) + torch.sum(grad_y, (1, 2)) + + return reduction(image_loss, M) + + + +class SSIMAE(nn.Module): + def __init__(self): + super().__init__() + + def forward(self, depth_preds, depth_gt, mask_valid): + depth_pred_aligned, depth_gt_aligned = masked_shift_and_scale(depth_preds, depth_gt, mask_valid) + ssi_mae_loss = masked_l1_loss(depth_pred_aligned, depth_gt_aligned, mask_valid) + return ssi_mae_loss + + +class GradientMatchingTerm(nn.Module): + def __init__(self, scales=4, reduction='batch-based'): + super().__init__() + + if reduction == 'batch-based': + self.__reduction = reduction_batch_based + else: + self.__reduction = reduction_image_based + + self.__scales = scales + + def forward(self, prediction, target, mask): + total = 0 + + for scale in range(self.__scales): + step = pow(2, scale) + + total += gradient_loss(prediction[:, ::step, ::step], target[:, ::step, ::step], + mask[:, ::step, ::step], reduction=self.__reduction) + + return total + + +class MidasLoss(nn.Module): + def __init__(self, alpha=0.1, scales=4, 
reduction='image-based'): + super().__init__() + + self.__ssi_mae_loss = SSIMAE() + self.__gradient_matching_term = GradientMatchingTerm(scales=scales, reduction=reduction) + self.__alpha = alpha + + def forward(self, prediction, target_inverse, mask): + prediction_inverse = 1 / (prediction+1e-6) + target = 1 / (target_inverse+1e-6) + ssi_loss = self.__ssi_mae_loss(prediction, target, mask) + target_inverse = target_inverse.squeeze(1) + prediction_inverse = prediction_inverse.squeeze(1) + mask = mask.squeeze(1) + + scale, shift = compute_scale_and_shift(prediction_inverse, target_inverse, mask) + prediction_ssi = scale.view(-1, 1, 1) * prediction_inverse + shift.view(-1, 1, 1) + reg_loss = self.__gradient_matching_term(prediction_ssi, target_inverse, mask) + if self.__alpha > 0: + total = ssi_loss + self.__alpha * reg_loss + return total, ssi_loss, reg_loss \ No newline at end of file diff --git a/source/losses/sat_depth_loss.py b/source/losses/sat_depth_loss.py new file mode 100644 index 0000000000000000000000000000000000000000..eb1e546f58b7604f88c9b8a8cdde1dce3eda81ce --- /dev/null +++ b/source/losses/sat_depth_loss.py @@ -0,0 +1,16 @@ +import torch + + +def compute_depth_loss(depth_loss_function, pred_depth, gt_depth): + """Compute scale-invariant depth loss for satellite view. + + Args: + depth_loss_function: A ``MidasLoss`` instance. + pred_depth: Predicted depth map. + gt_depth: Ground-truth depth map. + + Returns: + Scalar loss value (the first element returned by ``MidasLoss``). + """ + mask = torch.ones_like(gt_depth) > 0 + return depth_loss_function(pred_depth + 1, (gt_depth * 5 + 1), mask)[0] diff --git a/source/rendering/__init__.py b/source/rendering/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..c3294144aafcdd19c16c8b1daf490683cb6ffdd9 --- /dev/null +++ b/source/rendering/__init__.py @@ -0,0 +1,8 @@ +# python3.8 +"""Collects all rendering related modules.""" + +from .point_sampler import PointSampler +from .point_representer import PointRepresenter +from .point_integrator import PointIntegrator + +__all__ = ['PointSampler', 'PointRepresenter', 'PointIntegrator'] diff --git a/source/rendering/__pycache__/__init__.cpython-310.pyc b/source/rendering/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c88caeee9a56cb51cf7b4376af46279126b36348 Binary files /dev/null and b/source/rendering/__pycache__/__init__.cpython-310.pyc differ diff --git a/source/rendering/__pycache__/__init__.cpython-38.pyc b/source/rendering/__pycache__/__init__.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..944cdd87651b870dbd7ea51c1977e343173b4e81 Binary files /dev/null and b/source/rendering/__pycache__/__init__.cpython-38.pyc differ diff --git a/source/rendering/__pycache__/aabb.cpython-310.pyc b/source/rendering/__pycache__/aabb.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..080f74cafe9f54ab531a0657f62747735f2f48ec Binary files /dev/null and b/source/rendering/__pycache__/aabb.cpython-310.pyc differ diff --git a/source/rendering/__pycache__/aabb.cpython-38.pyc b/source/rendering/__pycache__/aabb.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4d12e94844cc24e6778e8d179eea53fde0496907 Binary files /dev/null and b/source/rendering/__pycache__/aabb.cpython-38.pyc differ diff --git a/source/rendering/__pycache__/mlp_model.cpython-310.pyc b/source/rendering/__pycache__/mlp_model.cpython-310.pyc new file mode 100644 
index 0000000000000000000000000000000000000000..ff760c3497dda16becf6e2d9237c8e8e1d35df56 Binary files /dev/null and b/source/rendering/__pycache__/mlp_model.cpython-310.pyc differ diff --git a/source/rendering/__pycache__/mlp_model.cpython-38.pyc b/source/rendering/__pycache__/mlp_model.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..78819728ebd02624a518f100e915a55f90d36dff Binary files /dev/null and b/source/rendering/__pycache__/mlp_model.cpython-38.pyc differ diff --git a/source/rendering/__pycache__/pano2perspective.cpython-310.pyc b/source/rendering/__pycache__/pano2perspective.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..21495604049acbf9a2364492e8727976623d6e26 Binary files /dev/null and b/source/rendering/__pycache__/pano2perspective.cpython-310.pyc differ diff --git a/source/rendering/__pycache__/point_integrator.cpython-310.pyc b/source/rendering/__pycache__/point_integrator.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8dcdfe3d10e8166a3646c7b3a3547a8788553689 Binary files /dev/null and b/source/rendering/__pycache__/point_integrator.cpython-310.pyc differ diff --git a/source/rendering/__pycache__/point_integrator.cpython-38.pyc b/source/rendering/__pycache__/point_integrator.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..da81090a22f49c09f32439b615829026eb69e6c0 Binary files /dev/null and b/source/rendering/__pycache__/point_integrator.cpython-38.pyc differ diff --git a/source/rendering/__pycache__/point_representer.cpython-310.pyc b/source/rendering/__pycache__/point_representer.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d39705c58bda7a1989b04d19d96e6aa5a80d9fe2 Binary files /dev/null and b/source/rendering/__pycache__/point_representer.cpython-310.pyc differ diff --git a/source/rendering/__pycache__/point_representer.cpython-38.pyc b/source/rendering/__pycache__/point_representer.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b4a1b58d5df6b0b89f2826a267f6a9e1cca587c7 Binary files /dev/null and b/source/rendering/__pycache__/point_representer.cpython-38.pyc differ diff --git a/source/rendering/__pycache__/point_sampler.cpython-310.pyc b/source/rendering/__pycache__/point_sampler.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4319296e25ac6caf8202ffcc99c77ea24567f307 Binary files /dev/null and b/source/rendering/__pycache__/point_sampler.cpython-310.pyc differ diff --git a/source/rendering/__pycache__/point_sampler.cpython-38.pyc b/source/rendering/__pycache__/point_sampler.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7c0296951d0923d939af157f2a57d4f3c64a3367 Binary files /dev/null and b/source/rendering/__pycache__/point_sampler.cpython-38.pyc differ diff --git a/source/rendering/__pycache__/sat2density_transform_eg3d.cpython-310.pyc b/source/rendering/__pycache__/sat2density_transform_eg3d.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1645e49a2ed1d4be6c9ddde62b249ba4cc8dbbf1 Binary files /dev/null and b/source/rendering/__pycache__/sat2density_transform_eg3d.cpython-310.pyc differ diff --git a/source/rendering/__pycache__/sat2density_transform_eg3d.cpython-38.pyc b/source/rendering/__pycache__/sat2density_transform_eg3d.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..bf11affe337a851e13f01d03753a886116c70c3d 
Binary files /dev/null and b/source/rendering/__pycache__/sat2density_transform_eg3d.cpython-38.pyc differ
diff --git a/source/rendering/__pycache__/transform_perspective.cpython-310.pyc b/source/rendering/__pycache__/transform_perspective.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..582876c7e18185c4ec9c4cb36e4502b8a50d8798
Binary files /dev/null and b/source/rendering/__pycache__/transform_perspective.cpython-310.pyc differ
diff --git a/source/rendering/__pycache__/transform_perspective.cpython-38.pyc b/source/rendering/__pycache__/transform_perspective.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..e331f7789794123ac258f190798ed396c3fb009c
Binary files /dev/null and b/source/rendering/__pycache__/transform_perspective.cpython-38.pyc differ
diff --git a/source/rendering/__pycache__/utils.cpython-310.pyc b/source/rendering/__pycache__/utils.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..cb0d275eb07b4c98c1ae3ac817258104097faafe
Binary files /dev/null and b/source/rendering/__pycache__/utils.cpython-310.pyc differ
diff --git a/source/rendering/__pycache__/utils.cpython-38.pyc b/source/rendering/__pycache__/utils.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..0b097b9b9e8b1e0980e6f334aebe349c6e6f1d55
Binary files /dev/null and b/source/rendering/__pycache__/utils.cpython-38.pyc differ
diff --git a/source/rendering/aabb.py b/source/rendering/aabb.py
new file mode 100644
index 0000000000000000000000000000000000000000..1211186cc2398e51eda3488ac47f603e465a4aa8
--- /dev/null
+++ b/source/rendering/aabb.py
@@ -0,0 +1,58 @@
+import numpy as np
+import torch
+from typing import Literal, Tuple
+# borrowed from https://github.com/nerfstudio-project/nerfstudio
+from einops import rearrange
+
+def intersect_aabb(
+    origins: torch.Tensor,
+    directions: torch.Tensor,
+    aabb: torch.Tensor = torch.tensor([-1., -1., -1., 1., 1., 1.]).float(),
+    max_bound: float = 1e10,
+    invalid_value: float = 1e10,
+) -> Tuple[torch.Tensor, torch.Tensor]:
+    """
+    Ray/AABB intersection via the slab method.
+
+    Args:
+        origins: [N, 3] tensor of 3D positions
+        directions: [N, 3] tensor of normalized directions
+        aabb: [6] tensor describing the box as [x_min, y_min, z_min, x_max, y_max, z_max]
+        max_bound: Maximum value of t_max
+        invalid_value: Value returned for rays that miss the box
+
+    Returns:
+        t_min, t_max - two tensors of shape [N] holding the entry and exit
+        distances of each ray, measured from its origin.
+    """
+    # Flatten image-shaped inputs and move everything to the origins' device.
+    if len(origins.shape) == 4:
+        origins = rearrange(origins, 'b h w c -> (b h w) c')
+        directions = rearrange(directions, 'b h w c -> (b h w) c')
+    directions = directions.to(origins.device)
+    aabb = aabb.to(origins.device)
+    B = max(origins.size(0), directions.size(0))
+    tx_min = (aabb[:3] - origins) / directions
+    tx_max = (aabb[3:] - origins) / directions
+
+    t_min = torch.stack((tx_min, tx_max)).amin(dim=0)
+    t_max = torch.stack((tx_min, tx_max)).amax(dim=0)
+
+    t_min = t_min.amax(dim=-1)
+    t_max = t_max.amin(dim=-1)
+
+    t_min = torch.clamp(t_min, min=0, max=max_bound)
+    t_max = torch.clamp(t_max, min=0, max=max_bound)
+
+    cond = t_max <= t_min
+    # Rays that miss the box leave t_max <= t_min; they are masked out with an
+    # explicit tensor fill below (working around what looks like a behavior
+    # change across torch versions).
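+    # For example (illustrative values): against the default [-1, 1]^3 box,
+    # a ray with origin (0, 0, -2) and direction (0, 0, 1) returns
+    # t_min = 1 (entry) and t_max = 3 (exit); flipping the direction to
+    # (0, 0, -1) leaves t_max <= t_min, so both entries get invalid_value.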
+ invalid_value = torch.tensor(invalid_value).repeat(B).float().to(origins.device) + t_min = torch.where(cond, invalid_value, t_min) + t_max = torch.where(cond, invalid_value, t_max) + # tmax is what I need + return t_min, t_max + +def intersect_aabb_end(origin,dir,min=0,max=4): + t_max = intersect_aabb(origin,dir)[1] + assert torch.isnan(t_max).any() == False , "nan in t_max of intersect_aabb_end" + assert min < t_max.min() < max, "t_max out of range %s, min is %s, max is %s" % (t_max.min(), min, max) + return t_max diff --git a/source/rendering/mlp_model.py b/source/rendering/mlp_model.py new file mode 100644 index 0000000000000000000000000000000000000000..c3d49d769bad4f5da24c68160800e60b0854c242 --- /dev/null +++ b/source/rendering/mlp_model.py @@ -0,0 +1,88 @@ +import torch +import torch.nn as nn +import numpy as np +from einops import repeat + +class SKYMLP(nn.Module): + r"""MLP converting ray directions to sky features.""" + def __init__(self, in_channels, style_dim, L=None,out_channels_c=3, + hidden_channels=256,is_pos_embedding = True): + super(SKYMLP, self).__init__() + self.is_pos_embedding = is_pos_embedding + self.L = L + self.fc_z_a = nn.Linear(style_dim, hidden_channels, bias=False) + input_channel = in_channels+ 2*self.L*in_channels if is_pos_embedding else in_channels + self.fc1 = nn.Linear(input_channel, hidden_channels) + self.fc2 = nn.Linear(hidden_channels*2, hidden_channels) + # self.fc3 = nn.Linear(hidden_channels, hidden_channels) + # self.fc4 = nn.Linear(hidden_channels, hidden_channels) + # self.fc5 = nn.Linear(hidden_channels, hidden_channels) + self.fc_out_c = nn.Linear(hidden_channels, out_channels_c) + self.act = nn.LeakyReLU(negative_slope=0.2, inplace=True) + def positional_encoding(self,input): # [B,...,N] + shape = input.shape + freq = 2**torch.arange(self.L,dtype=torch.float32).to(input.device)*np.pi # [L] + spectrum = input[...,None]*freq # [B,...,N,L] + sin,cos = spectrum.sin(),spectrum.cos() # [B,...,N,L] + input_enc = torch.stack([sin,cos],dim=-2) # [B,...,N,2,L] + input_enc = input_enc.view(*shape[:-1],-1) # [B,...,2NL] + return input_enc + def forward(self, x, z): + r"""Forward network + Args: + x (... x in_channels tensor): Ray direction embeddings. + z (... x style_dim tensor): Style codes. 
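+        Returns:
+            (... x out_channels_c tensor): Sky colors in [0, 1] (the head
+                ends with a sigmoid).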
+ """ + x = torch.cat([x,self.positional_encoding(x)],dim=-1) if self.is_pos_embedding else x + z = self.fc_z_a(z) + assert len(x.shape) == 4 + (H,W )= x.shape[1:3] + z = repeat(z,'b c -> b h w c',h=H,w=W) + # z = z.repeat(1,H,W,1) + # y = self.act(self.fc1(x) + z) + # # cat + y = self.act(torch.cat([self.fc1(x),z],dim=-1)) + y = self.act(self.fc2(y)) + c = self.fc_out_c(y) + c = torch.sigmoid(c) + return c + +class MLPNetwork2(nn.Module): + """Defines fully-connected layer head in EG3D.""" + + def __init__(self, input_dim, hidden_dim, output_dim,style_dim=270): + super().__init__() + + self.net0 = nn.Linear(input_dim, hidden_dim) + self.net0_act = nn.Softplus() + self.net1_feature = nn.Linear(hidden_dim, hidden_dim//2) + self.net1_density = nn.Linear(hidden_dim, 1) + self.style_dim = style_dim + self.style_squ = nn.Linear(self.style_dim,hidden_dim//2) + self.grd_color_convert = nn.Linear(hidden_dim, output_dim) + + def forward(self, point_features, style=None, only_density=False): + + N, M, C = point_features.shape + point_features = point_features.view(N * M, C) + y = self.net0(point_features) + y = self.net0_act(y) + density = self.net1_density(y).view(N, M, -1) + result = {} + result['density'] = density + if only_density: + return result + + color = self.net1_feature(y).view(N, M, -1) + if style is None: + style = repeat(torch.zeros([self.style_dim]), 'd -> n m d', n=N,m=M).float().to(point_features.device) + style = self.style_squ(style) + + if len(style.shape) == 2: + style = repeat(style, 'n d -> n m d', m=M).float().to(point_features.device) + combine_color_style = torch.cat([color, style], dim=-1) + color = self.grd_color_convert(combine_color_style) + + color = torch.sigmoid(color) + result['color'] = color + return result diff --git a/source/rendering/pano2perspective.py b/source/rendering/pano2perspective.py new file mode 100644 index 0000000000000000000000000000000000000000..8ccdcb5acadc4fdce8716a1a98e782575424c094 --- /dev/null +++ b/source/rendering/pano2perspective.py @@ -0,0 +1,68 @@ +import cv2 +import numpy as np + +def xyz2lonlat(xyz): + atan2 = np.arctan2 + asin = np.arcsin + + norm = np.linalg.norm(xyz, axis=-1, keepdims=True) + xyz_norm = xyz / norm + x = xyz_norm[..., 0:1] + y = xyz_norm[..., 1:2] + z = xyz_norm[..., 2:] + + lon = atan2(x, z) + lat = asin(y) + lst = [lon, lat] + + out = np.concatenate(lst, axis=-1) + return out + +def lonlat2XY(lonlat, shape): + X = (lonlat[..., 0:1] / (2 * np.pi) + 0.5) * (shape[1] - 1) + Y = (lonlat[..., 1:] / (np.pi) + 0.5) * (shape[0] - 1) + lst = [X, Y] + out = np.concatenate(lst, axis=-1) + + return out + +def GetPerspective(image, intrinsic_four, yaw, pitch, height, width): + # + # for yaw, higher is looking right, lower is looking left + # for pitch, higher is looking down, lower is looking up + # + if not type(image) == list: + image = [image] + THETA = yaw + PHI = -pitch + fx,fy,cx,cy = intrinsic_four + K = np.array([ + [fx, 0, cx], + [0, fy, cy], + [0, 0, 1], + ], np.float32) + K_inv = np.linalg.inv(K) + + x = np.arange(width) + y = np.arange(height) + x, y = np.meshgrid(x, y) + z = np.ones_like(x) + xyz = np.concatenate([x[..., None], y[..., None], z[..., None]], axis=-1) + xyz = xyz @ K_inv.T + + y_axis = np.array([0.0, 1.0, 0.0], np.float32) + x_axis = np.array([1.0, 0.0, 0.0], np.float32) + R1, _ = cv2.Rodrigues(y_axis * np.radians(THETA)) + R2, _ = cv2.Rodrigues(np.dot(R1, x_axis) * np.radians(PHI)) + R = R2 @ R1 + xyz = xyz @ R.T + lonlat = xyz2lonlat(xyz) + XY = lonlat2XY(lonlat, 
shape=image[0].shape).astype(np.float32) + result = [] + for img in image: + persp = cv2.remap(img, XY[..., 0], XY[..., 1], cv2.INTER_CUBIC, borderMode=cv2.BORDER_WRAP) + persp = np.clip(persp, 0, 255).astype(np.uint8) + result.append(persp) + if len(result) == 1: + return result[0] + return result diff --git a/source/rendering/point_integrator.py b/source/rendering/point_integrator.py new file mode 100644 index 0000000000000000000000000000000000000000..52af9300e3191d61bc3352b222cde757986dff44 --- /dev/null +++ b/source/rendering/point_integrator.py @@ -0,0 +1,263 @@ +# python3.8 +"""Contains the function of ray marching. + +Ray marching focuses on a single marching ray, which goes through a collection +of particles (points). Each point in the 3D space is represented by emitted +color and volume density. The final color to appear for each ray can be obtained +by accumulating the per-point color regarding the per-point density. + +Ray marching is an important step for Neural Radiance Field (NeRF). + +Paper: https://arxiv.org/pdf/2003.08934.pdf +""" + +import torch +import torch.nn.functional as F + +__all__ = ['PointIntegrator'] + +_DENSITY_CLAMP_MODES = ['relu', 'softplus', 'mipnerf'] +_COLOR_CLAMP_MODES = ['none', 'widen_sigmoid'] + +EPS = 1e-3 + + +class PointIntegrator(torch.nn.Module): + """Defines the class to accumulate points along each ray. + + This class implements the `forward()` function for ray marching, which + includes the following steps: + + 1. Get the color and density of the points for each ray. + 2. Get alpha values for alpha compositing. + 3. Get accumulated transmittances. + 4. Get composite color and density with weighted sum (i.e., integration). + + More details can be found in Section 4 of paper + + https://arxiv.org/pdf/2003.08934.pdf + """ + + def __init__(self, + use_mid_point=True, + use_dist=True, + max_radial_dist=1e10, + density_noise_std=0.0, + density_clamp_mode='relu', + color_clamp_mode='none', + normalize_color=False, + delta_modulate_scalar=1.0, + use_white_background=False, + scale_color=True, + normalize_radial_dist=False, + clip_radial_dist=False): + """Initializes hyper-parameters for ray marching. + + Args: + use_mid_point: Whether to use the middle point between two adjacent + points on each ray for accumulation. Defaults to `True`. + use_dist: Whether to consider the distance between two adjacent + points on each ray for accumulation. If set as `False`, the + distance between two adjacent points is constantly set as `1`. + Defaults to `True`. + max_radial_dist: The maximum radial distance between a particular + point to the camera. This argument is used to prevent the ray + from going too far away. Defaults to `1e10`. + density_noise_std: Standard deviation of the gaussian noise added to + densities. + density_clamp_mode: Mode of clamping densities. Defaults to `relu`. + color_clamp_mode: Mode of clamping colors. Defaults to `none`. + normalize_color: Whether to normalize the output composite color per + ray. Defaults to `False`. + delta_modulate_scalar: Scalar value to modulate delta of radial + distance. + use_white_background: Whether to use white background. Defaults to + `False`. + scale_color: Whether to scale the output composite color to range + (-1, 1). Defaults to `True`. + normalize_radial_dist: Whether to normalize the output composite + radial distance per ray. Defaults to `True`. + clip_radial_dist: Whether to clip the output composite radial + distance. Defaults to `True`. 
+ """ + super().__init__() + self.use_mid_point = use_mid_point + self.use_dist = use_dist + self.max_radial_dist = max_radial_dist + self.density_noise_std = density_noise_std + self.density_clamp_mode = density_clamp_mode + self.color_clamp_mode = color_clamp_mode + self.normalize_color = normalize_color + self.delta_modulate_scalar = delta_modulate_scalar + self.use_white_background = use_white_background + self.scale_color = scale_color + self.normalize_radial_dist = normalize_radial_dist + self.clip_radial_dist = clip_radial_dist + + def forward(self, colors, densities, radii, **kwargs): + """Integrates points along each ray. + + For simplicity, we define the following notations: + + `N` denotes batch size. + `R` denotes the number of rays, which usually equals `H * W`. + `K` denotes the number of points on each ray. + + Args: + colors: Per-point emitted color, with shape [N, R, K, C]. Here `C` + denotes the number of color channels. Note that, the color can + be represented by gray value (`C = 1`), RGB values (`C = 3`), or + a feature vector (such as `C = 64`). + densities: Per-point volume density, with shape [N, R, K, 1]. Here, + the density can be roughly interpreted as how likely a ray will + be blocked by this point. + radii: Per-point radial distance, with shape [N, R, K, 1]. Here, the + distance is measured by treating the camera as the origin. + **kwargs: Additional keyword arguments to override the variables + initialized in `__init__()`. + + Returns: + A dictionary, containing + - `composite_color`: The final per-ray composite color (or + color feature), with shape [N, R, C]. + - `composite_radial_dist`: The final per-ray composite radial + distance, with shape [N, R, 1]. + - `weights`: Per-point weight for integral, with shape + [N, R, K, 1]. + - `T_end`: The accumulated transmittance along the ray from + the start point `p_s` to the end point `p_e` in the + foreground scene. This can be interpreted as the probability + of the ray travelling from `p_s` to `p_e` without hitting + any other particles in the foreground scene. This variable + is with shape [N, R, 1]. + """ + # Parse arguments. + use_mid_point = kwargs.get('use_mid_point', self.use_mid_point) + use_dist = kwargs.get('use_dist', self.use_dist) + max_radial_dist = kwargs.get('max_radial_dist', self.max_radial_dist) + density_noise_std = kwargs.get('density_noise_std', + self.density_noise_std) + density_clamp_mode = kwargs.get( + 'density_clamp_mode', self.density_clamp_mode) + color_clamp_mode = kwargs.get('color_clamp_mode', self.color_clamp_mode) + normalize_color = kwargs.get('normalize_color', self.normalize_color) + delta_modulate_scalar = kwargs.get( + 'delta_modulate_scalar', self.delta_modulate_scalar) + use_white_background = kwargs.get( + 'use_white_background', self.use_white_background) + scale_color = kwargs.get('scale_color', self.scale_color) + normalize_radial_dist = kwargs.get( + 'normalize_radial_dist', self.normalize_radial_dist) + clip_radial_dist = kwargs.get('clip_radial_dist', self.clip_radial_dist) + + # Check inputs. 
+ assert colors.ndim == 4 + N, R, K, _ = colors.shape + assert densities.shape == (N, R, K, 1) + assert radii.shape == (N, R, K, 1) + density_clamp_mode = density_clamp_mode.lower() + if density_clamp_mode not in _DENSITY_CLAMP_MODES: + raise ValueError(f'Invalid clamp mode: `{density_clamp_mode}`!\n' + f'Modes allowed: {_DENSITY_CLAMP_MODES}.') + color_clamp_mode = color_clamp_mode.lower() + if color_clamp_mode not in _COLOR_CLAMP_MODES: + raise ValueError(f'Invalid clamp mode: `{color_clamp_mode}`!\n' + f'Modes allowed: {_COLOR_CLAMP_MODES}.') + + # Compute distances between adjacent points on each ray. Such a distance + # is termed as `delta` in the paper (Eq. (3)). + deltas = radii[:, :, 1:, :] - radii[:, :, :-1, :] # [N, R, K-1, 1] + + if delta_modulate_scalar != 1: + deltas = torch.ones_like(deltas) * delta_modulate_scalar + + if use_mid_point: # Using K-1 points on each ray. + colors = (colors[:, :, :-1, :] + colors[:, :, 1:, :]) / 2 + densities = (densities[:, :, :-1, :] + densities[:, :, 1:, :]) / 2 + radii = (radii[:, :, :-1, :] + radii[:, :, 1:, :]) / 2 + else: # Using K points on each ray. + # Append a maximum distance to make sure all points have reference. + delta_last = max_radial_dist * torch.ones_like(deltas[:, :, :1, :]) + deltas = torch.cat([deltas, delta_last], dim=2) # [N, R, K, 1] + + ray_dirs = kwargs.get('ray_dirs') + if ray_dirs is not None: # [N, R, 3] + assert ray_dirs.shape == (N, R, 3) + ray_dirs = ray_dirs.unsqueeze(-1) # [N, R, 3, 1] + deltas = deltas * torch.norm(ray_dirs, dim=-2, + keepdim=True) # [N, R, K, 1] + + if not use_dist: + deltas[:] = 1 + + if 'bg_index' in kwargs: + bg_index = F.one_hot(kwargs['bg_index'].squeeze(-1), + num_classes=deltas.shape[-2]).to(torch.bool) + bg_index = bg_index.unsqueeze(-1) + deltas[bg_index] = max_radial_dist + + if density_noise_std > 0: + densities = densities + density_noise_std * torch.randn_like( + densities) + + if density_clamp_mode == 'relu': + densities = F.relu(densities + 3) + elif density_clamp_mode == 'softplus': + densities = F.softplus(densities) + elif density_clamp_mode == 'mipnerf': + densities = F.softplus(densities - 1) + else: + raise ValueError(f'Not implemented clamping mode: ' + f'`{density_clamp_mode}`!\n') + + if color_clamp_mode == 'widen_sigmoid': + colors = torch.sigmoid(colors) * (1 + 2 * EPS) - EPS + + # Compute per-point alpha values. See Eq. (3) in the paper. + alphas = 1 - torch.exp(- deltas * densities) + if not use_mid_point and max_radial_dist > 0: + alphas[:, :, -1, :] = 1 + + if 'is_valid' in kwargs: + alphas = alphas * kwargs['is_valid'] + + # Compute per-point accumulated transmittance. See Eq. (3) in the paper. + # Here, we shift `alpha` forward by one index, because the transmittance + # of each point is only related to its previous points, excluding + # itself. + alphas_shifted = torch.cat( + [torch.ones_like(alphas[:, :, :1, :]), 1 - alphas + 1e-10], dim=2) + T = torch.cumprod(alphas_shifted, dim=2)[:, :, :-1, :] # Transmittance. + + # Compute per-point integral weights. + weights = alphas * T + weights_sum = weights.sum(dim=2) + + # Get per-ray color. + composite_color = torch.sum(weights * colors, dim=2) + if normalize_color: + composite_color = composite_color / weights_sum + if use_white_background: + composite_color = composite_color + 1 - weights_sum + if scale_color: + composite_color = composite_color * 2 - 1 + + # Get per-ray radial distance. 
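+        # This is the expected termination distance of the ray,
+        # sum_k w_k * t_k; with normalize_radial_dist it becomes
+        # sum_k w_k * t_k / sum_k w_k, which stays well scaled on rays whose
+        # accumulated opacity does not reach 1.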
+        composite_radial_dist = torch.sum(weights * radii, dim=2)
+        if normalize_radial_dist:
+            composite_radial_dist = composite_radial_dist / weights_sum
+        if clip_radial_dist:
+            composite_radial_dist = torch.nan_to_num(
+                composite_radial_dist, float('inf'))
+            composite_radial_dist = torch.clip(
+                composite_radial_dist, torch.min(radii), torch.max(radii))
+
+        results = {
+            'composite_color': composite_color,
+            'composite_radial_dist': composite_radial_dist,
+            'weight': weights,
+            'T_end': T[:, :, -1, :],
+            'opacity': weights_sum,
+        }
+
+        return results
diff --git a/source/rendering/point_representer.py b/source/rendering/point_representer.py
new file mode 100644
index 0000000000000000000000000000000000000000..9081c0dfc5270fa3b9914cad75b48f0ab41c56aa
--- /dev/null
+++ b/source/rendering/point_representer.py
@@ -0,0 +1,673 @@
+# python3.8
+"""Contains the functions to represent a point in 3D space.
+
+Typically, a point can be represented by its 3D coordinates, by retrieving from
+a feature volume, or by combining triplane features.
+
+Paper (coordinate): https://arxiv.org/pdf/2003.08934.pdf
+Paper (feature volume): https://arxiv.org/pdf/2112.10759.pdf
+Paper (triplane): https://arxiv.org/pdf/2112.07945.pdf
+"""
+
+from einops import rearrange
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+__all__ = ['PointRepresenter']
+
+_REPRESENTATION_TYPES = [
+    'coordinate', 'volume', 'triplane', 'hybrid', 'mpi', 'oneplane',
+    'oneplane_multi'
+]
+
+
+class PointRepresenter(nn.Module):
+    """Defines the class to get per-point representation.
+
+    This class implements the `forward()` function to get the representation
+    based on the per-point 3D coordinates and the reference representation
+    (such as a feature volume or triplane features).
+    """
+
+    def __init__(self,
+                 representation_type='coordinate',
+                 triplane_axes=None,
+                 mpi_levels=None,
+                 coordinate_scale=None,
+                 bound=None,
+                 return_eikonal=False,
+                 ):
+        """Initializes hyper-parameters for getting point representations.
+
+        NOTE:
+
+        When using the triplane representation, the three planes default to
+
+        [
+            [[1, 0, 0], [0, 1, 0], [0, 0, 1]],
+            [[1, 0, 0], [0, 0, 1], [0, 1, 0]],
+            [[0, 0, 1], [0, 1, 0], [1, 0, 0]]
+        ]
+
+        where for each plane, the first two rows stand for the plane axes
+        while the third row stands for the plane normal.
+
+        Args:
+            representation_type: Type of representation used to describe a
+                point in the 3D space. Defaults to `coordinate`.
+            triplane_axes: Plane axes used by the triplane representation,
+                with shape [n_planes, 3, 3]. Defaults to `None`, in which case
+                the default planes described above are used.
+            mpi_levels: Predefined level set values used by the `mpi`
+                representation. Defaults to `None`.
+            coordinate_scale: Scale factor to normalize coordinates.
+                Defaults to `None`.
+            bound: Bound used to normalize coordinates, with shape [1, 2, 3].
+                Defaults to `None`.
+            return_eikonal: If the eikonal loss is to be used, we utilize the
+                function `grid_sample_customized()` instead of
+                `F.grid_sample()` to avoid errors in computing the second
+                derivative.
+
+        Note that at most one of the two parameters used for normalizing
+        coordinates (`coordinate_scale` and `bound`) should be provided.
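+
+        Example (a minimal usage sketch; the tensor shapes below are
+        illustrative assumptions rather than requirements of this class):
+
+            representer = PointRepresenter(representation_type='triplane')
+            points = torch.randn(2, 4096, 3)        # [N, R * K, 3]
+            planes = torch.randn(2, 3, 32, 64, 64)  # [N, n_planes, C, H, W]
+            feats = representer(points, planes)     # [N, R * K, 32]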
+ """ + super().__init__() + self.coordinate_scale = None + if (coordinate_scale is not None) and (coordinate_scale > 0): + self.coordinate_scale = coordinate_scale + if bound is not None: + self.register_buffer('bound', bound) + else: + self.bound = None + self.return_eikonal = return_eikonal + + representation_type = representation_type.lower() + if representation_type not in _REPRESENTATION_TYPES: + raise ValueError(f'Invalid representation type: ' + f'`{representation_type}`!\n' + f'Types allowed: {_REPRESENTATION_TYPES}.') + + self.representation_type = representation_type + if self.representation_type in ['coordinate', 'volume']: + pass + elif self.representation_type in ['triplane', 'hybrid']: + if triplane_axes is None: + self.register_buffer( + 'triplane_axes', + torch.tensor([[[1, 0, 0], [0, 1, 0], [0, 0, 1]], + [[1, 0, 0], [0, 0, 1], [0, 1, 0]], + [[0, 0, 1], [0, 1, 0], [1, 0, 0]]], + dtype=torch.float32)) + else: + self.register_buffer('triplane_axes', triplane_axes) + elif self.representation_type in ['oneplane', 'oneplane_multi']: + self.register_buffer( + 'oneplane_axes', + torch.tensor([[[1, 0, 0], [0, 1, 0], [0, 0, 1]]], + dtype=torch.float32)) + elif self.representation_type == 'mpi': + self.register_buffer('mpi_levels', mpi_levels) + else: + raise NotImplementedError(f'Not implemented representation type: ' + f'`{self.representation_type}`!\n') + + def forward(self, + points, + ref_representation=None, + align_corners=False): + """Gets per-point representation based on its coordinates. + + For simplicity, we define the following notations: + + `N` denotes batch size. + `R` denotes the number of rays, which usually equals `H * W`. + `K` denotes the number of points on each ray. + `C` denotes the dimension of per-point representation. + + Args: + points: Per-point 3D coordinates, with shape [N, R * K, 3]. + ref_representation: The reference representation, depending on the + representation type used. For example, this field will be + ignored if `self.representation_type` is set as `coordinate`, + a feature volume is expected if `self.representation_type` is + set as `volume`, while triplane features are expected if + `self.representation_type` is set as `triplane`. Defaults to + `None`. + + Returns: + Per-point representation, with shape [N, R * K, C]. + """ + if self.representation_type == 'coordinate': + return points + if self.representation_type == 'mpi': + return retrieve_from_mpi(points=points, # [N, R, K, 3] + isosurfaces=ref_representation, + levels=self.mpi_levels) + + # Normalize point coordinates to the desired range, typically [-1, 1]. + if self.coordinate_scale is not None: + normalized_points = (2 / self.coordinate_scale) * points + elif self.bound is not None: + normalized_points = (points - self.bound[:, :1]) / ( + self.bound[:, 1:] - self.bound[:, :1]) # To range [0, 1]. + normalized_points = 2 * normalized_points - 1 # To range [-1, 1]. 
+        else:
+            normalized_points = points
+
+        if self.representation_type == 'volume':
+            return retrieve_from_volume(
+                coordinates=normalized_points,
+                volume=ref_representation)
+        if self.representation_type == 'triplane':
+            return retrieve_from_planes(
+                plane_axes=self.triplane_axes.to(points.device),
+                plane_features=ref_representation,
+                coordinates=normalized_points,
+                align_corners=align_corners,
+                return_eikonal=self.return_eikonal,
+            )
+        if self.representation_type == 'oneplane':
+            return retrieve_from_one_plane(
+                plane_axes=self.oneplane_axes.to(points.device),
+                plane_features=ref_representation,
+                coordinates=normalized_points,
+                align_corners=align_corners,
+                mean=False
+            )
+        if self.representation_type == 'oneplane_multi':
+            return retrieve_from_one_plane(
+                plane_axes=self.oneplane_axes.to(points.device),
+                plane_features=ref_representation,
+                coordinates=normalized_points,
+                align_corners=align_corners,
+                mean=True
+            )
+        if self.representation_type == 'hybrid':
+            assert isinstance(ref_representation, (list, tuple))
+            triplane = ref_representation[0]
+            feature_volume = ref_representation[1]
+            point_features_triplane = retrieve_from_planes(
+                plane_axes=self.triplane_axes.to(points.device),
+                plane_features=triplane,
+                coordinates=normalized_points,
+                align_corners=align_corners,
+                return_eikonal=self.return_eikonal)
+            point_features_volume = retrieve_from_volume(
+                coordinates=normalized_points,
+                volume=feature_volume)
+            point_features = torch.cat(
+                [point_features_volume, point_features_triplane], dim=-1)
+            return point_features
+
+        raise NotImplementedError(f'Not implemented representation type: '
+                                  f'`{self.representation_type}`!\n')
+
+
+def grid_sample_3d(volume, coordinates):
+    """Performs grid sampling in 3D space. Given 3D point coordinates, this
+    function samples values from the volume. Note that it is similar to
+    `torch.nn.functional.grid_sample()` in the case of 5-D inputs.
+
+    Args:
+        volume: The given volume, with shape [N, C, D, H, W].
+        coordinates: Input 3D point coordinates, with shape
+            [N, 1, 1, d * h * w, 3].
+
+    Returns:
+        sampled_vals: Sampled values, with shape [N, C, 1, 1, d * h * w].
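+
+    NOTE:
+        The coordinate mapping `((ix + 1) / 2) * (IW - 1)` follows the
+        `align_corners=True` convention of `F.grid_sample()`. A minimal
+        sanity-check sketch (assuming coordinates stay inside [-1, 1], so
+        that border clamping and zero padding never differ):
+
+            vol = torch.randn(1, 2, 4, 4, 4)
+            grid = torch.rand(1, 1, 1, 10, 3) * 2 - 1
+            ref = F.grid_sample(vol, grid, mode='bilinear',
+                                align_corners=True)
+            assert torch.allclose(grid_sample_3d(vol, grid), ref, atol=1e-5)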
+ """ + N, C, ID, IH, IW = volume.shape + _, D, H, W, _ = coordinates.shape + + ix = coordinates[..., 0] + iy = coordinates[..., 1] + iz = coordinates[..., 2] + + ix = ((ix + 1) / 2) * (IW - 1) + iy = ((iy + 1) / 2) * (IH - 1) + iz = ((iz + 1) / 2) * (ID - 1) + with torch.no_grad(): + ix_tnw = torch.floor(ix) + iy_tnw = torch.floor(iy) + iz_tnw = torch.floor(iz) + + ix_tne = ix_tnw + 1 + iy_tne = iy_tnw + iz_tne = iz_tnw + + ix_tsw = ix_tnw + iy_tsw = iy_tnw + 1 + iz_tsw = iz_tnw + + ix_tse = ix_tnw + 1 + iy_tse = iy_tnw + 1 + iz_tse = iz_tnw + + ix_bnw = ix_tnw + iy_bnw = iy_tnw + iz_bnw = iz_tnw + 1 + + ix_bne = ix_tnw + 1 + iy_bne = iy_tnw + iz_bne = iz_tnw + 1 + + ix_bsw = ix_tnw + iy_bsw = iy_tnw + 1 + iz_bsw = iz_tnw + 1 + + ix_bse = ix_tnw + 1 + iy_bse = iy_tnw + 1 + iz_bse = iz_tnw + 1 + + tnw = (ix_bse - ix) * (iy_bse - iy) * (iz_bse - iz) + tne = (ix - ix_bsw) * (iy_bsw - iy) * (iz_bsw - iz) + tsw = (ix_bne - ix) * (iy - iy_bne) * (iz_bne - iz) + tse = (ix - ix_bnw) * (iy - iy_bnw) * (iz_bnw - iz) + bnw = (ix_tse - ix) * (iy_tse - iy) * (iz - iz_tse) + bne = (ix - ix_tsw) * (iy_tsw - iy) * (iz - iz_tsw) + bsw = (ix_tne - ix) * (iy - iy_tne) * (iz - iz_tne) + bse = (ix - ix_tnw) * (iy - iy_tnw) * (iz - iz_tnw) + + with torch.no_grad(): + torch.clamp(ix_tnw, 0, IW - 1, out=ix_tnw) + torch.clamp(iy_tnw, 0, IH - 1, out=iy_tnw) + torch.clamp(iz_tnw, 0, ID - 1, out=iz_tnw) + + torch.clamp(ix_tne, 0, IW - 1, out=ix_tne) + torch.clamp(iy_tne, 0, IH - 1, out=iy_tne) + torch.clamp(iz_tne, 0, ID - 1, out=iz_tne) + + torch.clamp(ix_tsw, 0, IW - 1, out=ix_tsw) + torch.clamp(iy_tsw, 0, IH - 1, out=iy_tsw) + torch.clamp(iz_tsw, 0, ID - 1, out=iz_tsw) + + torch.clamp(ix_tse, 0, IW - 1, out=ix_tse) + torch.clamp(iy_tse, 0, IH - 1, out=iy_tse) + torch.clamp(iz_tse, 0, ID - 1, out=iz_tse) + + torch.clamp(ix_bnw, 0, IW - 1, out=ix_bnw) + torch.clamp(iy_bnw, 0, IH - 1, out=iy_bnw) + torch.clamp(iz_bnw, 0, ID - 1, out=iz_bnw) + + torch.clamp(ix_bne, 0, IW - 1, out=ix_bne) + torch.clamp(iy_bne, 0, IH - 1, out=iy_bne) + torch.clamp(iz_bne, 0, ID - 1, out=iz_bne) + + torch.clamp(ix_bsw, 0, IW - 1, out=ix_bsw) + torch.clamp(iy_bsw, 0, IH - 1, out=iy_bsw) + torch.clamp(iz_bsw, 0, ID - 1, out=iz_bsw) + + torch.clamp(ix_bse, 0, IW - 1, out=ix_bse) + torch.clamp(iy_bse, 0, IH - 1, out=iy_bse) + torch.clamp(iz_bse, 0, ID - 1, out=iz_bse) + + volume = volume.view(N, C, ID * IH * IW) + + tnw_val = torch.gather(volume, 2, + (iz_tnw * IW * IH + iy_tnw * IW + + ix_tnw).long().view(N, 1, + D * H * W).repeat(1, C, 1)) + tne_val = torch.gather(volume, 2, + (iz_tne * IW * IH + iy_tne * IW + + ix_tne).long().view(N, 1, + D * H * W).repeat(1, C, 1)) + tsw_val = torch.gather(volume, 2, + (iz_tsw * IW * IH + iy_tsw * IW + + ix_tsw).long().view(N, 1, + D * H * W).repeat(1, C, 1)) + tse_val = torch.gather(volume, 2, + (iz_tse * IW * IH + iy_tse * IW + + ix_tse).long().view(N, 1, + D * H * W).repeat(1, C, 1)) + bnw_val = torch.gather(volume, 2, + (iz_bnw * IW * IH + iy_bnw * IW + + ix_bnw).long().view(N, 1, + D * H * W).repeat(1, C, 1)) + bne_val = torch.gather(volume, 2, + (iz_bne * IW * IH + iy_bne * IW + + ix_bne).long().view(N, 1, + D * H * W).repeat(1, C, 1)) + bsw_val = torch.gather(volume, 2, + (iz_bsw * IW * IH + iy_bsw * IW + + ix_bsw).long().view(N, 1, + D * H * W).repeat(1, C, 1)) + bse_val = torch.gather(volume, 2, + (iz_bse * IW * IH + iy_bse * IW + + ix_bse).long().view(N, 1, + D * H * W).repeat(1, C, 1)) + + sampled_vals = (tnw_val.view(N, C, D, H, W) * tnw.view(N, 1, D, H, W) + + tne_val.view(N, C, D, H, W) * 
tne.view(N, 1, D, H, W) +
+                    tsw_val.view(N, C, D, H, W) * tsw.view(N, 1, D, H, W) +
+                    tse_val.view(N, C, D, H, W) * tse.view(N, 1, D, H, W) +
+                    bnw_val.view(N, C, D, H, W) * bnw.view(N, 1, D, H, W) +
+                    bne_val.view(N, C, D, H, W) * bne.view(N, 1, D, H, W) +
+                    bsw_val.view(N, C, D, H, W) * bsw.view(N, 1, D, H, W) +
+                    bse_val.view(N, C, D, H, W) * bse.view(N, 1, D, H, W))
+
+    return sampled_vals
+
+
+def grid_sample_customized(input, grid):
+    """Customized `grid_sample()` operation.
+
+    Since the original PyTorch `grid_sample()` operator does not support second
+    derivative computation during the backward pass, we customize this operator.
+
+    Args:
+        input: Input tensor, with shape [N, C, IH, IW].
+        grid: Flow-field grid used for sampling, with shape [N, H, W, 2].
+
+    Returns:
+        output: Output tensor, with shape [N, C, H, W].
+    """
+    N, C, IH, IW = input.shape
+    _, H, W, _ = grid.shape
+
+    # Guard against NaNs in the sampling grid.
+    if torch.any(torch.isnan(grid)):
+        grid = torch.ones_like(grid)
+        print('Warning: NaN detected in the sampling grid; replaced with ones.')
+
+    ix = grid[..., 0]
+    iy = grid[..., 1]
+
+    ix = ((ix + 1) / 2) * (IW - 1)
+    iy = ((iy + 1) / 2) * (IH - 1)
+    with torch.no_grad():
+        ix_nw = torch.floor(ix)
+        iy_nw = torch.floor(iy)
+        ix_ne = ix_nw + 1
+        iy_ne = iy_nw
+        ix_sw = ix_nw
+        iy_sw = iy_nw + 1
+        ix_se = ix_nw + 1
+        iy_se = iy_nw + 1
+
+    nw = (ix_se - ix) * (iy_se - iy)
+    ne = (ix - ix_sw) * (iy_sw - iy)
+    sw = (ix_ne - ix) * (iy - iy_ne)
+    se = (ix - ix_nw) * (iy - iy_nw)
+
+    with torch.no_grad():
+        torch.clamp(ix_nw, 0, IW - 1, out=ix_nw)
+        torch.clamp(iy_nw, 0, IH - 1, out=iy_nw)
+
+        torch.clamp(ix_ne, 0, IW - 1, out=ix_ne)
+        torch.clamp(iy_ne, 0, IH - 1, out=iy_ne)
+
+        torch.clamp(ix_sw, 0, IW - 1, out=ix_sw)
+        torch.clamp(iy_sw, 0, IH - 1, out=iy_sw)
+
+        torch.clamp(ix_se, 0, IW - 1, out=ix_se)
+        torch.clamp(iy_se, 0, IH - 1, out=iy_se)
+
+    input = input.view(N, C, IH * IW)
+
+    nw_val = torch.gather(input, 2, (iy_nw * IW + ix_nw).long().view(
+        N, 1, H * W).repeat(1, C, 1))
+    ne_val = torch.gather(input, 2, (iy_ne * IW + ix_ne).long().view(
+        N, 1, H * W).repeat(1, C, 1))
+    sw_val = torch.gather(input, 2, (iy_sw * IW + ix_sw).long().view(
+        N, 1, H * W).repeat(1, C, 1))
+    se_val = torch.gather(input, 2, (iy_se * IW + ix_se).long().view(
+        N, 1, H * W).repeat(1, C, 1))
+
+    output = (nw_val.view(N, C, H, W) * nw.view(N, 1, H, W) +
+              ne_val.view(N, C, H, W) * ne.view(N, 1, H, W) +
+              sw_val.view(N, C, H, W) * sw.view(N, 1, H, W) +
+              se_val.view(N, C, H, W) * se.view(N, 1, H, W))
+
+    return output
+
+
+def retrieve_from_volume(coordinates, volume):
+    """Samples point features from a feature volume.
+
+    Args:
+        coordinates: Coordinate of input 3D points, with shape [N, R * K, 3].
+        volume: Feature volume, with shape [N, C, D, H, W].
+
+    Returns:
+        output_features: Output sampled point features, with shape
+            [N, R * K, C].
+    """
+    grid_coords = coordinates[:, None, None]  # [N, 1, 1, R * K, 3]
+    output_features = grid_sample_3d(volume, grid_coords)  # [N, C, 1, 1, R * K]
+    output_features = output_features[:, :, 0, 0]  # [N, C, R * K]
+    output_features = output_features.permute(0, 2, 1)  # [N, R * K, C]
+
+    return output_features
+
+
+def project_points_onto_planes(points, planes):
+    """Projects 3D points onto a batch of 2D planes.
+
+    To project a 3D point `P` onto a 2D plane defined by a normal vector `n`
+    and a point `Q` that lies on the plane, one can use the following formula:
+
+        P_proj = P - dot(P - Q, n) * n / dot(n, n)
+
+    where:
+        `P_proj` is the projected point on the plane;
+        `dot()` is the dot product.
+
+    And `Q` can be chosen as the origin (0, 0, 0) of the coordinate system.
+    Meanwhile, if `n` is a normalized vector, then the projection formula is
+    simplified as:
+
+        P_proj = P - dot(P, n) * n
+
+    Args:
+        points: Point coordinates, with shape [N, M, 3], where `M` is the
+            number of points in each batch and equals `R * K`.
+        planes: Planes, with shape [n_planes, 3, 3], where `n_planes`
+            is the number of planes. Here, a plane is represented by two
+            vector axes and one normal vector. For instance, if a plane is
+            represented by:
+                `[[0, 0, 1],
+                  [0, 1, 0],
+                  [1, 0, 0]]`,
+            then its axes are the third and second axes of the coordinate
+            system, and its normal vector is `[1, 0, 0]`.
+
+    Returns:
+        projections: Projections, with shape [N * n_planes, R * K, 2].
+    """
+    plane_normals = planes[:, 2]
+    N, M, _ = points.shape  # `M` equals `R * K`.
+    n_planes, _ = plane_normals.shape
+
+    # Normalize the normals to unit vectors.
+    plane_normals = F.normalize(plane_normals, dim=1)
+
+    # Unsqueeze, expand and reshape tensors.
+    points = points.unsqueeze(1).expand(
+        -1, n_planes, -1, -1).reshape(N * n_planes, M,
+                                      3)  # [N * n_planes, R * K, 3]
+    plane_normals = plane_normals.unsqueeze(0).expand(N, -1, -1).reshape(
+        N * n_planes, 3)  # [N * n_planes, 3]
+    plane_normals = plane_normals.unsqueeze(1).expand(
+        -1, M, -1)  # [N * n_planes, R * K, 3]
+
+    # Compute the projections.
+    projections = points - torch.sum(points * plane_normals,
+                                     dim=-1).unsqueeze(-1) * plane_normals
+
+    # Extract the projection values from different planes.
+    plane_axes = planes.unsqueeze(0).expand(N, -1, -1, -1).reshape(
+        N * n_planes, 3, 3)
+    projections = torch.bmm(projections, plane_axes.permute(0, 2, 1))[..., :2]
+    return projections
+
+
+def retrieve_from_planes(plane_axes,
+                         plane_features,
+                         coordinates,
+                         mode='bilinear',
+                         align_corners=False,
+                         return_eikonal=False,
+                         ):
+    """Samples point features from triplane. Borrowed from
+
+    https://github.com/NVlabs/eg3d/blob/main/eg3d/training/volumetric_rendering/renderer.py
+
+    Args:
+        plane_axes: Axes of triplane, with shape [n_planes, 3, 3].
+        plane_features: Triplane features, with shape [N, n_planes, C, H, W].
+        coordinates: Coordinate of input 3D points, with shape [N, R * K, 3].
+        mode: Interpolation mode.
+        align_corners: Whether to align corners in `F.grid_sample()`.
+        return_eikonal: Whether to use `grid_sample_customized()` so that the
+            second derivative (needed by the eikonal loss) can be computed.
+
+    Returns:
+        output_features: Output sampled point features, with shape
+            [N, R * K, C].
+    """
+    N, n_planes, C, H, W = plane_features.shape
+    _, M, _ = coordinates.shape  # `M` equals `R * K`.
+    plane_features = rearrange(plane_features,
+                               'N n_planes c h w -> (N n_planes) c h w')
+
+    projected_coordinates = project_points_onto_planes(
+        coordinates,
+        plane_axes).unsqueeze(1)  # [N * n_planes, 1, R * K, 2]
+    if return_eikonal:
+        output_features = grid_sample_customized(
+            plane_features,
+            projected_coordinates.float())  # [N * n_planes, C, 1, R * K]
+    else:
+        output_features = F.grid_sample(
+            plane_features,
+            projected_coordinates.float(),
+            mode=mode,
+            padding_mode='zeros',
+            align_corners=align_corners)  # [N * n_planes, C, 1, R * K]
+    output_features = output_features.permute(
+        0, 3, 2, 1)  # [N * n_planes, R * K, 1, C]
+    output_features = output_features.reshape(N, n_planes, M,
+                                              C)  # [N, n_planes, R * K, C]
+    output_features = output_features.mean(1)  # [N, R * K, C]
+
+    return output_features
+
+
+def retrieve_from_one_plane(plane_axes,
+                            plane_features,
+                            coordinates,
+                            mode='bilinear',
+                            align_corners=False,
+                            return_eikonal=False,
+                            mean=False,
+                            ):
+    """Samples point features from a single plane plus line features.
Borrowed from
+
+    https://github.com/NVlabs/eg3d/blob/main/eg3d/training/volumetric_rendering/renderer.py
+
+    Args:
+        plane_axes: Axes of the single plane, with shape [1, 3, 3].
+        plane_features: A list `[plane, line]`, where `plane` holds the plane
+            features, with shape [N, 1, C, H, W], and `line` holds the line
+            features along the z axis, with shape [N, C, L].
+        coordinates: Coordinate of input 3D points, with shape [N, R * K, 3].
+        mode: Interpolation mode.
+        align_corners: Whether to align corners in `F.grid_sample()`.
+        return_eikonal: Whether to use `grid_sample_customized()` so that the
+            second derivative (needed by the eikonal loss) can be computed.
+        mean: If `True`, the plane and line features are averaged; otherwise,
+            they are concatenated along the channel dimension.
+
+    Returns:
+        output_features: Output sampled point features, with shape
+            [N, R * K, C] (or [N, R * K, 2 * C] when `mean` is `False`).
+    """
+    assert isinstance(plane_features, list)
+    N, num_plane, C, H, W = plane_features[0].shape
+    _, M, _ = coordinates.shape  # `M` equals `R * K`.
+    one_plane_features = plane_features[0].view(N, C, H, W)  # Single plane.
+
+    line_features = plane_features[1]  # [N, C, L]
+    line_features_4d = line_features.unsqueeze(-1)  # [N, C, L, 1]
+    z_point = coordinates[..., -1:]  # [N, R * K, 1]
+    z_point = z_point.unsqueeze(1)  # [N, 1, R * K, 1]
+    # `grid_sample()` expects (x, y) coordinates: the first (dummy) channel
+    # indexes the singleton width dimension, the second indexes the line.
+    y_fixed = torch.zeros_like(z_point)
+    coordinates_z_cat = torch.cat([y_fixed, z_point], dim=-1)  # [N, 1, R * K, 2]
+
+    z_features = F.grid_sample(
+        line_features_4d,
+        coordinates_z_cat.float(),
+        mode=mode,
+        padding_mode='zeros',
+        align_corners=align_corners)  # [N, C, 1, R * K]
+    z_features = z_features.permute(0, 2, 3, 1)  # [N, 1, R * K, C]
+
+    projected_coordinates = project_points_onto_planes(
+        coordinates,
+        plane_axes).unsqueeze(1)  # [N * n_planes, 1, R * K, 2]
+    if return_eikonal:
+        output_features = grid_sample_customized(
+            one_plane_features,
+            projected_coordinates.float())  # [N * n_planes, C, 1, R * K]
+    else:
+        output_features = F.grid_sample(
+            one_plane_features,
+            projected_coordinates.float(),
+            mode=mode,
+            padding_mode='zeros',
+            align_corners=align_corners)  # [N * n_planes, C, 1, R * K]
+    output_features = output_features.permute(
+        0, 3, 2, 1)  # [N * n_planes, R * K, 1, C]
+    output_features = output_features.reshape(N, 1, M,
+                                              C)  # [N, 1, R * K, C]
+    if not mean:
+        output_features = torch.cat([output_features, z_features], dim=-1)
+    else:
+        output_features = torch.cat([output_features, z_features], dim=1)
+
+    return output_features.mean(1)
+
+
+def retrieve_from_mpi(points, isosurfaces, levels):
+    """Gets intersections between camera rays and levels.
+
+    Args:
+        points: Coordinate of input 3D points, with shape [N, R, K, 3].
+        isosurfaces: Isosurface scalars predicted by MPIPredictor.
+        levels: Predefined level set values.
+ + Returns: + intersections: The intersections between camera rays and the levels, + with shape [N, R, num_levels - 1, 3] + is_valid: Whether a level is valid or not, boolean tensor with shape + [N, R, num_levels - 1, 1] + """ + + s_l = isosurfaces[:, :, :-1] + s_h = isosurfaces[:, :, 1:] + + K = points.shape[2] + cost = torch.linspace(K - 1, 0, K - 1).float() + cost = cost.to(points.device).reshape(1, 1, -1, 1) + + x_interval = [] + s_interval = [] + for l in levels: + r = (s_h - l <= 0) * (l - s_l <= 0) * 2 - 1 + r = r * cost + _, indices = torch.max(r, dim=-2, keepdim=True) + x_l_select = torch.gather(points, -2, indices.expand(-1, -1, -1, 3)) + x_h_select = torch.gather(points, -2, indices.expand(-1, -1, -1, 3) + 1) + s_l_select = torch.gather(s_l, -2, indices) + s_h_select = torch.gather(s_h, -2, indices) + x_interval.append(torch.cat([x_l_select, x_h_select], dim=-2)) + s_interval.append(torch.cat([s_l_select, s_h_select], dim=-2)) + + intersections = [] + is_valid = [] + for interval, val, l in zip(x_interval, s_interval, levels): + x_l = interval[:, :, 0] + x_h = interval[:, :, 1] + s_l = val[:, :, 0] + s_h = val[:, :, 1] + scale = torch.where( + torch.abs(s_h - s_l) > 0.05, s_h - s_l, + torch.ones_like(s_h) * 0.05) + intersect = torch.where( + ((s_h - l <= 0) * (l - s_l <= 0)) & (torch.abs(s_h - s_l) > 0.05), + ((s_h - l) * x_l + (l - s_l) * x_h) / scale, x_h) + intersections.append(intersect) + is_valid.append(((s_h - l <= 0) * (l - s_l <= 0)).to(intersect.dtype)) + + intersections = torch.stack(intersections, dim=-2) + is_valid = torch.stack(is_valid, dim=-2) + + return intersections, is_valid diff --git a/source/rendering/point_sampler.py b/source/rendering/point_sampler.py new file mode 100644 index 0000000000000000000000000000000000000000..a50d467fac92ed1bf5da564da0bdf34bd5eda8f5 --- /dev/null +++ b/source/rendering/point_sampler.py @@ -0,0 +1,1207 @@ +# python3.8 +"""Contains the functions to sample points in 3D space. + +Point sampling is primarily used for Neural Radiance Field (NeRF). + +Paper: https://arxiv.org/pdf/2003.08934.pdf +""" + +import numpy as np + +import torch +import torch.nn.functional as F + + +__all__ = ['PointSampler'] + +_POINT_SAMPLING_STRATEGIES = [ + 'uniform', 'normal', 'ray_dependent', 'point_dependent' +] + +_POINT_PERTURBING_STRATEGIES = [ + 'no', 'middle_uniform', 'uniform', 'self_uniform' +] + +_TENSOR_SAMPLING_STRATEGIES = [ + 'fix', 'uniform', 'normal', 'hybrid', 'truncated_normal' +] + + +class PointSampler(torch.nn.Module): + """Defines the class to help sample points. + + This class implements the `forward()` function for point sampling, which + includes the following steps: + + 1. Sample rays in the camera coordinate system. + 2. Sample points on each ray. + 3. Perturb points on each ray. + 4. Sample camera extrinsics. + 5. Transform points to the world coordinate system. + """ + + def __init__(self, + # Ray sampling related. + fov=30, + image_boundary_value=1.0, + focal=None, + x_axis_right=True, + y_axis_up=True, + z_axis_out=True, + x_pixel_shift=0.0, + y_pixel_shift=0.0, + selected_pixels=None, + patch_params=None, + # Point sampling (i.e., radial distance w.r.t. camera) related. + num_points=16, + point_strategy='uniform', + dis_min=None, + dis_max=None, + dis_mean=None, + dis_stddev=None, + per_ray_ref=None, + per_point_ref=None, + perturbation_strategy='middle_uniform', + # Camera sampling related. 
+ radius_strategy='fix', + radius_fix=None, + radius_min=None, + radius_max=None, + radius_mean=None, + radius_stddev=None, + polar_strategy='uniform', + polar_fix=None, + polar_min=None, + polar_max=None, + polar_mean=None, + polar_stddev=None, + azimuthal_strategy='uniform', + azimuthal_fix=None, + azimuthal_min=None, + azimuthal_max=None, + azimuthal_mean=None, + azimuthal_stddev=None, + use_spherical_uniform_position=False, + pitch_strategy='fix', + pitch_fix=0, + pitch_min=None, + pitch_max=None, + pitch_mean=None, + pitch_stddev=None, + yaw_strategy='fix', + yaw_fix=0, + yaw_min=None, + yaw_max=None, + yaw_mean=None, + yaw_stddev=None, + roll_strategy='fix', + roll_fix=0, + roll_min=None, + roll_max=None, + roll_mean=None, + roll_stddev=None): + """Initializes hyper-parameters for point sampling. + + Detailed description of each argument can be found in functions + `get_ray_per_pixel()`, `sample_points_per_ray()`, + `perturb_points_per_ray()`, and `sample_camera_extrinsics()`. + """ + super().__init__() + self.fov = fov + self.image_boundary_value = image_boundary_value + self.focal = focal + self.x_axis_right = x_axis_right + self.y_axis_up = y_axis_up + self.z_axis_out = z_axis_out + self.x_pixel_shift = x_pixel_shift + self.y_pixel_shift = y_pixel_shift + self.selected_pixels = selected_pixels + self.patch_params = patch_params + + self.num_points = num_points + self.point_strategy = point_strategy + self.dis_min = dis_min + self.dis_max = dis_max + self.dis_mean = dis_mean + self.dis_stddev = dis_stddev + self.per_ray_ref = per_ray_ref + self.per_point_ref = per_point_ref + self.perturbation_strategy = perturbation_strategy + + self.radius_strategy = radius_strategy + self.radius_fix = radius_fix + self.radius_min = radius_min + self.radius_max = radius_max + self.radius_mean = radius_mean + self.radius_stddev = radius_stddev + self.polar_strategy = polar_strategy + self.polar_fix = polar_fix + self.polar_min = polar_min + self.polar_max = polar_max + self.polar_mean = polar_mean + self.polar_stddev = polar_stddev + self.azimuthal_strategy = azimuthal_strategy + self.azimuthal_fix = azimuthal_fix + self.azimuthal_min = azimuthal_min + self.azimuthal_max = azimuthal_max + self.azimuthal_mean = azimuthal_mean + self.azimuthal_stddev = azimuthal_stddev + self.use_spherical_uniform_position = use_spherical_uniform_position + self.pitch_strategy = pitch_strategy + self.pitch_fix = pitch_fix + self.pitch_min = pitch_min + self.pitch_max = pitch_max + self.pitch_mean = pitch_mean + self.pitch_stddev = pitch_stddev + self.yaw_strategy = yaw_strategy + self.yaw_fix = yaw_fix + self.yaw_min = yaw_min + self.yaw_max = yaw_max + self.yaw_mean = yaw_mean + self.yaw_stddev = yaw_stddev + self.roll_strategy = roll_strategy + self.roll_fix = roll_fix + self.roll_min = roll_min + self.roll_max = roll_max + self.roll_mean = roll_mean + self.roll_stddev = roll_stddev + + def forward(self, + batch_size, + image_size, + cam2world_matrix=None, + patch_grid=None, + **kwargs): + """Samples points. + + `K` denotes the number of points on each ray. + + Args: + batch_size: Batch size of images. Denoted as `N`. + image_size: Size of the image. One element indicates square image, + while two elements stand for height and width respectively. + Denoted as `H` and `W`. + cam2world_matrix: Transformation matrix used to transform the camera + coordinate system to the world coordinate system, with shape + [N, 4, 4]. If given, the process of sampling camera extrinsics + will be skipped. 
+ patch_grid: A flow-field grid is used to sample patches from the + entire tensor, with shape [N, h, w, 2]. + **kwargs: Additional keyword arguments to override the variables + initialized in `__init__()`. + + Returns: + A dictionary, containing + - `camera_radius`: camera radius w.r.t. the world coordinate + system, with shape [N]. + - `camera_polar`: camera polar w.r.t. the world coordinate + system, with shape [N]. + - `camera_azimuthal`: camera azimuthal w.r.t. the world + coordinate system, with shape [N]. + - `camera_pitch`: camera pitch w.r.t. the camera coordinate + system, with shape [N]. + - `camera_yaw`: camera yaw w.r.t. the camera coordinate system, + with shape [N]. + - `camera_roll`: camera roll w.r.t. the camera coordinate + system, with shape [N]. + - `camera_pos`: camera position, i.e., the (x, y, z) coordinate + in the world coordinate system, with shape [N, 3]. + - `cam2world_matrix`: transformation matrix to transform the + camera coordinate system to the world coordinate system, + with shape [N, 4, 4]. + - `rays_camera`: ray directions in the camera coordinate system, + with shape [N, H, W, 3]. + - `rays_world`: ray directions in the world coordinate system, + with shape [N, H, W, 3]. + - `radii_raw`: raw per-point radial distance w.r.t. the camera + position, with shape [N, H, W, K]. + - `radii`: per-point radial distance after perturbation w.r.t. + the camera position, with shape [N, H, W, K]. + - `points_camera`: per-point coordinate in the camera coordinate + system, with shape [N, H, W, K, 3]. + - `points_world`: per-point coordinate in the world coordinate + system, with shape [N, H, W, K, 3]. + """ + fov = kwargs.get('fov', self.fov) + focal = kwargs.get('focal', self.focal) + image_boundary_value = kwargs.get( + 'image_boundary_value', self.image_boundary_value) + x_axis_right = kwargs.get('x_axis_right', self.x_axis_right) + y_axis_up = kwargs.get('y_axis_up', self.y_axis_up) + z_axis_out = kwargs.get('z_axis_out', self.z_axis_out) + x_pixel_shift = kwargs.get('x_pixel_shift', self.x_pixel_shift) + y_pixel_shift = kwargs.get('y_pixel_shift', self.y_pixel_shift) + selected_pixels = kwargs.get('selected_pixels', self.selected_pixels) + patch_params = kwargs.get('patch_params', self.patch_params) + + num_points = kwargs.get('num_points', self.num_points) + point_strategy = kwargs.get( + 'point_strategy', self.point_strategy) + dis_min = kwargs.get('dis_min', self.dis_min) + dis_max = kwargs.get('dis_max', self.dis_max) + dis_mean = kwargs.get('dis_mean', self.dis_mean) + dis_stddev = kwargs.get('dis_stddev', self.dis_stddev) + per_ray_ref = kwargs.get('per_ray_ref', self.per_ray_ref) + per_point_ref = kwargs.get('per_point_ref', self.per_point_ref) + perturbation_strategy = kwargs.get( + 'perturbation_strategy', self.perturbation_strategy) + + radius_strategy = kwargs.get('radius_strategy', self.radius_strategy) + radius_fix = kwargs.get('radius_fix', self.radius_fix) + radius_min = kwargs.get('radius_min', self.radius_min) + radius_max = kwargs.get('radius_max', self.radius_max) + radius_mean = kwargs.get('radius_mean', self.radius_mean) + radius_stddev = kwargs.get('radius_stddev', self.radius_stddev) + polar_strategy = kwargs.get('polar_strategy', self.polar_strategy) + polar_fix = kwargs.get('polar_fix', self.polar_fix) + polar_min = kwargs.get('polar_min', self.polar_min) + polar_max = kwargs.get('polar_max', self.polar_max) + polar_mean = kwargs.get('polar_mean', self.polar_mean) + polar_stddev = kwargs.get('polar_stddev', self.polar_stddev) + 
azimuthal_strategy = kwargs.get( + 'azimuthal_strategy', self.azimuthal_strategy) + azimuthal_fix = kwargs.get('azimuthal_fix', self.azimuthal_fix) + azimuthal_min = kwargs.get('azimuthal_min', self.azimuthal_min) + azimuthal_max = kwargs.get('azimuthal_max', self.azimuthal_max) + azimuthal_mean = kwargs.get('azimuthal_mean', self.azimuthal_mean) + azimuthal_stddev = kwargs.get('azimuthal_stddev', self.azimuthal_stddev) + use_spherical_uniform_position = kwargs.get( + 'use_spherical_uniform_position', + self.use_spherical_uniform_position) + pitch_strategy = kwargs.get('pitch_strategy', self.pitch_strategy) + pitch_fix = kwargs.get('pitch_fix', self.pitch_fix) + pitch_min = kwargs.get('pitch_min', self.pitch_min) + pitch_max = kwargs.get('pitch_max', self.pitch_max) + pitch_mean = kwargs.get('pitch_mean', self.pitch_mean) + pitch_stddev = kwargs.get('pitch_stddev', self.pitch_stddev) + yaw_strategy = kwargs.get('yaw_strategy', self.yaw_strategy) + yaw_fix = kwargs.get('yaw_fix', self.yaw_fix) + yaw_min = kwargs.get('yaw_min', self.yaw_min) + yaw_max = kwargs.get('yaw_max', self.yaw_max) + yaw_mean = kwargs.get('yaw_mean', self.yaw_mean) + yaw_stddev = kwargs.get('yaw_stddev', self.yaw_stddev) + roll_strategy = kwargs.get('roll_strategy', self.roll_strategy) + roll_fix = kwargs.get('roll_fix', self.roll_fix) + roll_min = kwargs.get('roll_min', self.roll_min) + roll_max = kwargs.get('roll_max', self.roll_max) + roll_mean = kwargs.get('roll_mean', self.roll_mean) + roll_stddev = kwargs.get('roll_stddev', self.roll_stddev) + + rays_camera = get_ray_per_pixel( + batch_size=batch_size, + image_size=image_size, + fov=fov, + boundary=image_boundary_value, + focal=focal, + x_axis_right=x_axis_right, + y_axis_up=y_axis_up, + z_axis_out=z_axis_out, + x_pixel_shift=x_pixel_shift, + y_pixel_shift=y_pixel_shift, + selected_pixels=selected_pixels, + patch_params=patch_params) + + if patch_grid is not None: + rays_camera = rays_camera.permute(0, 3, 1, 2) + rays_camera = F.grid_sample(rays_camera, + patch_grid, + mode='bilinear', + align_corners=True) + rays_camera = rays_camera.permute(0, 2, 3, 1) + _, h, w, _ = patch_grid.shape + image_size = h if h == w else (h, w) + + if selected_pixels is not None: + h, w = selected_pixels.shape[1:3] + image_size = h if h == w else (h, w) + + radii_raw = sample_points_per_ray(batch_size=batch_size, + image_size=image_size, + num_points=num_points, + strategy=point_strategy, + dis_min=dis_min, + dis_max=dis_max, + dis_mean=dis_mean, + dis_stddev=dis_stddev, + per_ray_ref=per_ray_ref, + per_point_ref=per_point_ref) + radii = perturb_points_per_ray(radii=radii_raw, + strategy=perturbation_strategy) + + if cam2world_matrix is not None: + camera_info = { + 'radius': None, + 'polar': None, + 'azimuthal': None, + 'pitch': None, + 'yaw': None, + 'roll': None, + 'camera_pos': None, + 'cam2world_matrix': cam2world_matrix, + } + else: + camera_info = sample_camera_extrinsics( + batch_size=batch_size, + radius_strategy=radius_strategy, + radius_fix=radius_fix, + radius_min=radius_min, + radius_max=radius_max, + radius_mean=radius_mean, + radius_stddev=radius_stddev, + polar_strategy=polar_strategy, + polar_fix=polar_fix, + polar_min=polar_min, + polar_max=polar_max, + polar_mean=polar_mean, + polar_stddev=polar_stddev, + azimuthal_strategy=azimuthal_strategy, + azimuthal_fix=azimuthal_fix, + azimuthal_min=azimuthal_min, + azimuthal_max=azimuthal_max, + azimuthal_mean=azimuthal_mean, + azimuthal_stddev=azimuthal_stddev, + 
use_spherical_uniform_position=use_spherical_uniform_position, + pitch_strategy=pitch_strategy, + pitch_fix=pitch_fix, + pitch_min=pitch_min, + pitch_max=pitch_max, + pitch_mean=pitch_mean, + pitch_stddev=pitch_stddev, + yaw_strategy=yaw_strategy, + yaw_fix=yaw_fix, + yaw_min=yaw_min, + yaw_max=yaw_max, + yaw_mean=yaw_mean, + yaw_stddev=yaw_stddev, + roll_strategy=roll_strategy, + roll_fix=roll_fix, + roll_min=roll_min, + roll_max=roll_max, + roll_mean=roll_mean, + roll_stddev=roll_stddev) + + points = get_point_coord( + rays_camera=rays_camera, + radii=radii, + cam2world_matrix=camera_info['cam2world_matrix']) + + return { + 'camera_radius': camera_info['radius'], # [N] + 'camera_polar': camera_info['polar'], # [N] + 'camera_azimuthal': camera_info['azimuthal'], # [N] + 'camera_pitch': camera_info['pitch'], # [N] + 'camera_yaw': camera_info['yaw'], # [N] + 'camera_roll': camera_info['roll'], # [N] + 'camera_pos': camera_info['camera_pos'], # [N, 3] + 'cam2world_matrix': camera_info['cam2world_matrix'], # [N, 4, 4] + 'rays_camera': rays_camera, # [N, H, W, 3] + 'rays_world': points['rays_world'], # [N, H, W, 3] + 'radii_raw': radii_raw, # [N, H, W, K] + 'radii': radii, # [N, H, W, K] + 'points_camera': points['points_camera'], # [N, H, W, K, 3] + 'points_world': points['points_world'], # [N, H, W, K, 3] + } + + +def get_ray_per_pixel(batch_size, + image_size, + fov, + boundary=1.0, + focal=None, + x_axis_right=True, + y_axis_up=True, + z_axis_out=True, + x_pixel_shift=0.0, + y_pixel_shift=0.0, + selected_pixels=None, + patch_params=None, + normalize=True): + """Gets ray direction for each image pixel under camera coordinate system. + + Each ray direction is represented by a vector, [x, y, z], under the + following coordinate system: + + - The origin is set at the camera position. + - The X axis is set as the horizontal direction of the image plane, with + `x_axis_right` controlling whether the positive direction points to the + right hand side. + - The Y axis is set as the vertical direction of the image plane, with + `y_axis_up` controlling whether the positive direction points to the + upside. + - The Z axis is set as the direction perpendicular to the image plane, with + `z_axis_out` controlling whether the positive direction points to the + outside. If true, then under the camera coordinate system, the z + coordinate of the image plane is negative. + - By default, where `x_axis_right`, `y_axis_up`, and `z_axis_out` are all + set as `True`, the above coordinate system is a right-hand one. 
+
+    Taking a 5x5 image (with boundary 1.0) as an example, the per-pixel (x, y)
+    coordinates (with `x_axis_right = True` and `y_axis_up = True`) should
+    look like:
+
+        (-1.0, 1.0)  (-0.5, 1.0)  (0.0, 1.0)  (0.5, 1.0)  (1.0, 1.0)
+        (-1.0, 0.5)  (-0.5, 0.5)  (0.0, 0.5)  (0.5, 0.5)  (1.0, 0.5)
+        (-1.0, 0.0)  (-0.5, 0.0)  (0.0, 0.0)  (0.5, 0.0)  (1.0, 0.0)
+        (-1.0, -0.5) (-0.5, -0.5) (0.0, -0.5) (0.5, -0.5) (1.0, -0.5)
+        (-1.0, -1.0) (-0.5, -1.0) (0.0, -1.0) (0.5, -1.0) (1.0, -1.0)
+
+    In the above case, if `x_pixel_shift = 0.5`, meaning all rays are sampled
+    at mid-pixels along the X axis, the X shift value will be
+
+        2 * boundary / (W - 1) * 0.5 = 2 * 1 / 4 * 0.5 = 0.25
+
+    Then, the per-pixel (x, y) coordinates should look like:
+
+        (-0.75, 1.0)  (-0.25, 1.0)  (0.25, 1.0)  (0.75, 1.0)  (1.25, 1.0)
+        (-0.75, 0.5)  (-0.25, 0.5)  (0.25, 0.5)  (0.75, 0.5)  (1.25, 0.5)
+        (-0.75, 0.0)  (-0.25, 0.0)  (0.25, 0.0)  (0.75, 0.0)  (1.25, 0.0)
+        (-0.75, -0.5) (-0.25, -0.5) (0.25, -0.5) (0.75, -0.5) (1.25, -0.5)
+        (-0.75, -1.0) (-0.25, -1.0) (0.25, -1.0) (0.75, -1.0) (1.25, -1.0)
+
+    NOTE:
+        The X-axis focal and Y-axis focal are assumed to be the same according
+        to the pinhole camera model.
+
+    Args:
+        batch_size: Batch size of images, each of which has the same ray
+            directions. Denoted as `N`.
+        image_size: Size of the image. One element indicates square image,
+            while two elements stand for height and width respectively.
+            Denoted as `H` and `W`.
+        fov: Field of view (along X axis) of the camera, in units of degrees.
+        boundary: The maximum value of the X coordinate. Defaults to `1.0`.
+        focal: Focal length of the camera. If not given, it will be calculated
+            from `fov` and `boundary` automatically. Note that the focal
+            length is assumed to be normalized by the image size. Defaults to
+            `None`.
+        x_axis_right: Whether the positive direction of X axis points to the
+            right hand side. Defaults to `True`.
+        y_axis_up: Whether the positive direction of Y axis points to the
+            upside. Defaults to `True`.
+        z_axis_out: Whether the positive direction of Z axis points to the
+            outside. Defaults to `True`.
+        x_pixel_shift: Pixel shift of each ray along X axis. Defaults to `0.0`.
+        y_pixel_shift: Pixel shift of each ray along Y axis. Defaults to `0.0`.
+        selected_pixels: Indices of a subset of pixels from which to sample
+            rays, with shape [N, h, w]. If not given, every pixel will be
+            assigned a ray. Defaults to `None`.
+        patch_params: Dictionary containing
+            - `scales`: scales of the sampling patch, with shape [N, 2];
+            - `offsets`: offsets of the sampling patch, with shape [N, 2].
+        normalize: Whether to normalize the sampled coordinates in pixel
+            space. Recall the camera ray formula in the NeRF paper:
+            `r = o + t * d`. If `normalize` is set to `True`, `t` represents
+            the radial distance; otherwise, `t` represents the depth.
+
+    Returns:
+        A tensor, with shape [N, H, W, 3] (or [N, h, w, 3] if `selected_pixels`
+        is given), representing the per-pixel ray direction. Each direction
+        is normalized to a unit vector.
+    """
+    # Check inputs.
+    assert isinstance(batch_size, int) and batch_size > 0
+    N = batch_size
+    assert isinstance(image_size, (int, list, tuple))
+    if isinstance(image_size, int):
+        H = image_size
+        W = image_size
+    else:
+        H, W = image_size
+    assert isinstance(H, int) and H > 0
+    assert isinstance(W, int) and W > 0
+    assert 0 < fov < 180
+    assert boundary > 0
+
+    # Get running device.
+    device = torch.cuda.current_device() if torch.cuda.is_available() else 'cpu'
+
+    # Get (x, y) grid by boundary.
+    max_x = boundary
+    if H == W:
+        max_y = boundary
+    else:
+        max_y = boundary / W * H
+    y, x = torch.meshgrid(torch.linspace(max_y, -max_y, H, device=device),
+                          torch.linspace(-max_x, max_x, W, device=device),
+                          indexing='ij')
+
+    if patch_params is not None:
+        # This patch sampling strategy is used in EpiGRAF and includes the
+        # following steps:
+        # 1). Shift [-1, 1] range into [0, 2];
+        # 2). Multiply by the patch size;
+        # 3). Shift back to [-1, 1];
+        # 4). Apply the offset (converted from [0, 1] to [0, 2]).
+        patch_scales = patch_params['scales']
+        patch_offsets = patch_params['offsets']
+        x = x.flatten().unsqueeze(0).repeat(N, 1)  # [N, H * W]
+        y = y.flatten().unsqueeze(0).repeat(N, 1)  # [N, H * W]
+        x = (x + 1.0) * patch_scales[:, 0].view(
+            N, 1) - 1.0 + patch_offsets[:, 0].view(N, 1) * 2.0  # [N, H * W]
+        y = (y + 1.0) * patch_scales[:, 1].view(
+            N, 1) - 1.0 + patch_offsets[:, 1].view(N, 1) * 2.0  # [N, H * W]
+
+    # Get the z coordinate of the image plane from the focal length
+    # (i.e., FOV).
+    if focal is None:
+        focal = boundary / np.tan((fov / 180 * np.pi) / 2)
+    z = -focal * torch.ones_like(x)  # Same shape as `x`.
+
+    # Adjust the positive direction of each axis.
+    if not x_axis_right:
+        x = -x
+    if not y_axis_up:
+        y = -y
+    if not z_axis_out:
+        z = -z
+
+    # Adjust pixel shift along X and Y axes.
+    x = x + 2 * max_x / (W - 1) * x_pixel_shift
+    y = y + 2 * max_y / (H - 1) * y_pixel_shift
+
+    if normalize:
+        # Normalize directions to unit vectors.
+        rays = F.normalize(torch.stack([x, y, z], dim=-1), dim=-1)  # [H, W, 3]
+    else:
+        rays = torch.stack([x / focal, y / focal, z / focal],
+                           dim=-1)  # [H, W, 3]
+
+    # Repeat the sampled rays along the batch dimension.
+    if rays.shape == (N, H * W, 3):
+        rays = rays.reshape(N, H, W, 3)
+    else:
+        rays = rays.unsqueeze(0).repeat(N, 1, 1, 1)  # [N, H, W, 3]
+
+    # Select rays of interest if needed.
+    if selected_pixels is not None:
+        assert selected_pixels.ndim == 3 and selected_pixels.shape[0] == N
+        _, h, w = selected_pixels.shape
+        indices = selected_pixels.reshape(N, h * w, 1).repeat(1, 1, 3)
+        rays = rays.reshape(N, H * W, 3)
+        rays = torch.gather(rays, dim=1, index=indices)  # [N, h * w, 3]
+        rays = rays.reshape(N, h, w, 3)
+
+    return rays
+
+
+def sample_points_per_ray(batch_size,
+                          image_size,
+                          num_points,
+                          strategy='uniform',
+                          dis_min=None,
+                          dis_max=None,
+                          dis_mean=None,
+                          dis_stddev=None,
+                          per_ray_ref=None,
+                          per_point_ref=None):
+    """Samples per-point radial distance on each ray.
+
+    This function is independent of ray directions; hence, each point is
+    represented by a single number indicating its radial distance to the
+    origin (i.e., the camera).
+
+    The following sampling strategies are supported:
+
+    - `uniform`:
+        For each ray, the points are uniformly spaced in the range
+        `[dis_min, dis_max]`.
+
+    - `normal`:
+        For each ray, the points are sampled subject to
+        `Gaussian(dis_mean, dis_stddev^2)`.
+
+    - `ray_dependent`:
+        Each ray follows a separate strategy, controlled by `per_ray_ref`.
+
+    - `point_dependent`:
+        Each point follows a separate strategy, controlled by `per_point_ref`.
+
+    Args:
+        batch_size: Batch size of images, for which points are sampled
+            independently. Denoted as `N`.
+        image_size: Size of the image. One element indicates square image,
+            while two elements stand for height and width respectively.
+            Denoted as `H` and `W`.
+        num_points: Number of points sampled on each ray. Denoted as `K`.
+        strategy: Strategy for point sampling. Defaults to `uniform`.
+        dis_min: Minimum radial distance (with camera as the origin) for each
+            point.
Defaults to `None`.
+        dis_max: Maximum radial distance (with camera as the origin) for each
+            point. Defaults to `None`.
+        dis_mean: Mean radial distance (with camera as the origin) for each
+            point. Defaults to `None`.
+        dis_stddev: Standard deviation of the radial distance (with camera as
+            the origin) for each point. Defaults to `None`.
+        per_ray_ref: Reference for each ray, which will guide the sampling
+            process. Shape [N, H, W, c] is expected, where `c` is the dimension
+            of a single reference. Defaults to `None`.
+        per_point_ref: Reference for each point, which will guide the sampling
+            process. Shape [N, H, W, K, c] is expected, where `c` is the
+            dimension of a single reference. Defaults to `None`.
+
+    Returns:
+        A tensor, with shape [N, H, W, K], representing the per-point radial
+            distance on each ray. All numbers should be positive, and the
+            distances on each ray should follow a non-descending order.
+
+    Raises:
+        ValueError: If the sampling strategy is not supported.
+        NotImplementedError: If the sampling strategy is not implemented.
+    """
+    # Check inputs.
+    assert isinstance(batch_size, int) and batch_size > 0
+    N = batch_size
+    assert isinstance(image_size, (int, list, tuple))
+    if isinstance(image_size, int):
+        H = image_size
+        W = image_size
+    else:
+        H, W = image_size
+    assert isinstance(H, int) and H > 0
+    assert isinstance(W, int) and W > 0
+    assert isinstance(num_points, int) and num_points > 0
+    K = num_points
+    strategy = strategy.lower()
+    if strategy not in _POINT_SAMPLING_STRATEGIES:
+        raise ValueError(f'Invalid point sampling strategy: `{strategy}`!\n'
+                         f'Strategies allowed: {_POINT_SAMPLING_STRATEGIES}.')
+
+    # Get running device.
+    device = torch.cuda.current_device() if torch.cuda.is_available() else 'cpu'
+
+    # Sample points according to strategy.
+    if strategy == 'uniform':
+        assert dis_max >= dis_min > 0
+        radii = torch.linspace(dis_min, dis_max, K, device=device)  # [K]
+        return radii.reshape(1, 1, 1, K).repeat(N, H, W, 1)  # [N, H, W, K]
+
+    if strategy == 'normal':
+        assert dis_mean > 0 and dis_stddev >= 0
+
+    if strategy == 'ray_dependent':
+        assert per_ray_ref.ndim == 4
+        assert per_ray_ref.shape[:3] == (N, H, W)
+
+    if strategy == 'point_dependent':
+        assert per_point_ref.ndim == 5
+        assert per_point_ref.shape[:4] == (N, H, W, K)
+
+    raise NotImplementedError(f'Not implemented point sampling strategy: '
+                              f'`{strategy}`!')
+
+
+def perturb_points_per_ray(radii, strategy='middle_uniform'):
+    """Perturbs point radii within their local range on each ray.
+
+    This implements the stratified sampling approach described in the original
+    NeRF paper.
+
+    `N`, `H`, `W`, `K` denote batch size, image height, image width, number of
+    points per ray, respectively.
+
+    The following perturbing strategies are supported:
+
+    - `no`:
+        Disable point perturbation.
+
+    - `middle_uniform`:
+        Each point is perturbed between two midpoints: one lies between the
+        point itself and its previous one on the same ray, while the other
+        lies between the point itself and its next one on the same ray.
+
+    - `uniform`:
+        Each point is perturbed between itself and its next one.
+        For example, given `n+1` points on a ray, [x_0, x_1, ..., x_n], the
+        perturbed points are [x_0', x_1', ..., x_n'] with distribution
+        x_i' ~ U(x_i, x_i + delta), where delta = x_1 - x_0 (for uniformly
+        spaced points this equals U(x_i, x_{i+1})).
+
+    - `self_uniform`:
+        Each point is perturbed around itself. For example, given `n+1`
+        points on a ray, [x_0, x_1, ..., x_n],
the perturbed points
+        are [x_0', x_1', ..., x_n'] with distribution
+        x_i' ~ U(x_i - 0.5 * delta, x_i + 0.5 * delta), where
+        delta = x_1 - x_0.
+
+    Args:
+        radii: A collection of point radii, with shape [N, H, W, K].
+        strategy: Strategy to perturb each point. Defaults to `middle_uniform`.
+
+    Returns:
+        A tensor, with shape [N, H, W, K], representing the per-point radial
+            distance on each ray. All numbers should be positive, and the
+            distances on each ray should follow a non-descending order.
+
+    Raises:
+        ValueError: If the input point radii have an invalid shape, or the
+            perturbing strategy is not supported.
+        NotImplementedError: If the perturbing strategy is not implemented.
+    """
+    # Check inputs.
+    if radii.ndim != 4:
+        raise ValueError(f'The input point radii should be with shape '
+                         f'[batch_size, height, width, num_points], '
+                         f'but `{radii.shape}` is received!')
+    strategy = strategy.lower()
+    if strategy not in _POINT_PERTURBING_STRATEGIES:
+        raise ValueError(f'Invalid point perturbing strategy: `{strategy}`!\n'
+                         f'Strategies allowed: {_POINT_PERTURBING_STRATEGIES}.')
+
+    if strategy == 'no':
+        return radii
+
+    if strategy == 'middle_uniform':
+        # Get midpoints.
+        midpoint = (radii[..., 1:] + radii[..., :-1]) / 2  # [N, H, W, K - 1]
+        # Get intervals.
+        left = torch.cat([radii[..., :1], midpoint], dim=-1)  # [N, H, W, K]
+        right = torch.cat([midpoint, radii[..., -1:]], dim=-1)  # [N, H, W, K]
+        # Uniformly sample within each interval.
+        t = torch.rand_like(radii)  # [N, H, W, K]
+        return left + (right - left) * t  # [N, H, W, K]
+
+    if strategy == 'uniform':
+        delta = radii[..., 1:2] - radii[..., 0:1]  # [N, H, W, 1]
+        t = torch.rand_like(radii)  # [N, H, W, K]
+        return radii + t * delta  # [N, H, W, K]
+
+    if strategy == 'self_uniform':
+        delta = radii[..., 1:2] - radii[..., 0:1]  # [N, H, W, 1]
+        t = torch.rand_like(radii) - 0.5  # [N, H, W, K]
+        return radii + t * delta  # [N, H, W, K]
+
+    raise NotImplementedError(f'Not implemented point perturbing strategy: '
+                              f'`{strategy}`!')
+
+
+def sample_tensor(size,
+                  strategy='uniform',
+                  entry_fix=None,
+                  entry_min=None,
+                  entry_max=None,
+                  entry_mean=None,
+                  entry_stddev=None):
+    """Samples a tensor according to the specified strategy.
+
+    The following sampling strategies are supported:
+
+    - `fix`:
+        Each entry is fixed as `entry_fix`.
+
+    - `uniform`:
+        Each entry is uniformly sampled from range `[entry_min, entry_max]`.
+
+    - `normal`:
+        Each entry is sampled subject to `Gaussian(entry_mean, entry_stddev^2)`.
+
+    - `hybrid`:
+        Each entry is 50% sampled with `uniform` and 50% sampled with `normal`.
+
+    - `truncated_normal`:
+        Each entry is sampled subject to a truncated normal distribution, with
+        `entry_min` and `entry_max` as the cut-off values.
+
+    Args:
+        size: Size of the sampled tensor. This field is expected to be an
+            integer, a list, or a tuple.
+        strategy: Strategy to sample each entry. Defaults to `uniform`.
+        entry_fix: Fixed value of the entry. Defaults to `None`.
+        entry_min: Minimum value of each entry. Defaults to `None`.
+        entry_max: Maximum value of each entry. Defaults to `None`.
+        entry_mean: Mean value of each entry. Defaults to `None`.
+        entry_stddev: Standard deviation of each entry. Defaults to `None`.
+
+    Returns:
+        A tensor, with expected size.
+
+    Raises:
+        ValueError: If the sampling strategy is not supported.
+        NotImplementedError: If the sampling strategy is not implemented.
+    """
+    # Check inputs.
+    if isinstance(size, int):
+        size = (size,)
+    elif isinstance(size, list):
+        size = tuple(size)
+    assert isinstance(size, tuple)
+    strategy = strategy.lower()
+    if strategy not in _TENSOR_SAMPLING_STRATEGIES:
+        raise ValueError(f'Invalid tensor sampling strategy: `{strategy}`!\n'
+                         f'Strategies allowed: {_TENSOR_SAMPLING_STRATEGIES}.')
+
+    # Get running device.
+    device = torch.cuda.current_device() if torch.cuda.is_available() else 'cpu'
+
+    if strategy == 'fix':
+        assert entry_fix is not None
+        return torch.ones(size, device=device) * entry_fix
+
+    if strategy == 'uniform':
+        assert entry_max >= entry_min
+        t = torch.rand(size, device=device)
+        return entry_min + (entry_max - entry_min) * t
+
+    if strategy == 'normal':
+        assert entry_mean is not None and entry_stddev >= 0
+        return torch.randn(size, device=device) * entry_stddev + entry_mean
+
+    if strategy == 'hybrid':
+        assert entry_max >= entry_min
+        assert entry_mean is not None and entry_stddev >= 0
+        if np.random.random() < 0.5:
+            t = torch.rand(size, device=device)
+            return entry_min + (entry_max - entry_min) * t
+        return torch.randn(size, device=device) * entry_stddev + entry_mean
+
+    if strategy == 'truncated_normal':
+        assert entry_max >= entry_min
+        assert entry_mean is not None and entry_stddev >= 0
+        tensor = torch.randn(size, device=device) * entry_stddev + entry_mean
+        tensor = torch.clamp(tensor, entry_min, entry_max)
+        return tensor
+
+    raise NotImplementedError(f'Not implemented tensor sampling strategy: '
+                              f'`{strategy}`!')
+
+
+def sample_camera_extrinsics(batch_size,
+                             radius_strategy='fix',
+                             radius_fix=None,
+                             radius_min=None,
+                             radius_max=None,
+                             radius_mean=None,
+                             radius_stddev=None,
+                             polar_strategy='uniform',
+                             polar_fix=None,
+                             polar_min=None,
+                             polar_max=None,
+                             polar_mean=None,
+                             polar_stddev=None,
+                             azimuthal_strategy='uniform',
+                             azimuthal_fix=None,
+                             azimuthal_min=None,
+                             azimuthal_max=None,
+                             azimuthal_mean=None,
+                             azimuthal_stddev=None,
+                             use_spherical_uniform_position=False,
+                             pitch_strategy='fix',
+                             pitch_fix=0,
+                             pitch_min=None,
+                             pitch_max=None,
+                             pitch_mean=None,
+                             pitch_stddev=None,
+                             yaw_strategy='fix',
+                             yaw_fix=0,
+                             yaw_min=None,
+                             yaw_max=None,
+                             yaw_mean=None,
+                             yaw_stddev=None,
+                             roll_strategy='fix',
+                             roll_fix=0,
+                             roll_min=None,
+                             roll_max=None,
+                             roll_mean=None,
+                             roll_stddev=None,
+                             y_axis_up=True):
+    """Samples camera extrinsics.
+
+    This function supports sampling camera extrinsics along 6 dimensions
+    (here, all angles are in units of radians):
+
+    - Camera position:
+        - radius: Distance from the camera position to the origin of the world
+            coordinate system.
+        - polar: The polar angle with respect to the origin of the world
+            coordinate system.
+        - azimuthal: The azimuthal angle with respect to the origin of the
+            world coordinate system.
+    - Camera orientation:
+        - pitch: Pitch angle (X axis) regarding the camera coordinate system.
+        - yaw: Yaw angle (Y axis) regarding the camera coordinate system.
+        - roll: Roll angle (Z axis) regarding the camera coordinate system.
+
+    It then converts the sampled extrinsics to a camera position and a
+    camera-to-world transformation matrix.
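+
+    As a concrete example (using the Y-up system described below): with
+    `radius = 1`, `polar = pi / 2`, and `azimuthal = 0`, the sampled camera
+    sits at position (1, 0, 0) and, given zero orientation, looks at the
+    origin.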
+
+    Currently, our framework supports two kinds of world coordinate systems:
+
+    (1) Y-axis pointing upward (default, `y_axis_up=True`):
+
+                     v: polar
+        Y            u: azimuth
+        ^
+        |v /
+        | /
+        |/
+        +---------> X
+       /\
+      /  \
+     / u  \
+    Z
+
+    (2) Z-axis pointing upward (`y_axis_up=False`):
+
+                     v: polar
+        Z            u: azimuth
+        ^
+        |v /
+        | /
+        |/
+        +---------> Y
+       /\
+      /  \
+     / u  \
+    X
+
+    More details about sampling as well as arguments can be found in function
+    `sample_tensor()`.
+
+    NOTE:
+        Without camera orientation (i.e., `pitch = 0, yaw = 0, roll = 0`), this
+        function assumes the camera points to the origin of the world
+        coordinate system. Furthermore, camera orientation controls the
+        rotation within the camera coordinate system, which is independent of
+        the transformation across coordinate systems. As a result, with a
+        non-zero orientation the camera does not necessarily point to the
+        origin of the world coordinate system anymore.
+
+    Args:
+        batch_size: Batch size of the sampled cameras. Denoted as `N`.
+        use_spherical_uniform_position: Whether to sample the camera position
+            subject to a spherical uniform distribution. Defaults to `False`.
+
+    Returns:
+        A dictionary, containing
+            - `radius`: camera radius w.r.t. the world coordinate system,
+                with shape [N].
+            - `polar`: camera polar w.r.t. the world coordinate system,
+                with shape [N].
+            - `azimuthal`: camera azimuthal w.r.t. the world coordinate
+                system, with shape [N].
+            - `pitch`: camera pitch w.r.t. the camera coordinate system,
+                with shape [N].
+            - `yaw`: camera yaw w.r.t. the camera coordinate system,
+                with shape [N].
+            - `roll`: camera roll w.r.t. the camera coordinate system,
+                with shape [N].
+            - `camera_pos`: camera position, i.e., the (x, y, z) coordinate
+                in the world coordinate system, with shape [N, 3].
+            - `cam2world_matrix`: transformation matrix to transform the
+                camera coordinate system to the world coordinate system,
+                with shape [N, 4, 4].
+    """
+    # Sample camera position.
+    radius = sample_tensor(size=batch_size,
+                           strategy=radius_strategy,
+                           entry_fix=radius_fix,
+                           entry_min=radius_min,
+                           entry_max=radius_max,
+                           entry_mean=radius_mean,
+                           entry_stddev=radius_stddev)
+
+    if use_spherical_uniform_position:
+        azimuthal = sample_tensor(
+            size=batch_size,
+            strategy='uniform',
+            entry_min=azimuthal_min if azimuthal_min is not None else 0,
+            entry_max=azimuthal_max if azimuthal_max is not None else 1)
+        azimuthal = (azimuthal - 0.5) * 2 * azimuthal_stddev + azimuthal_mean
+
+        polar_mean = polar_mean / np.pi
+        polar_stddev = polar_stddev / np.pi
+        polar = sample_tensor(
+            size=batch_size,
+            strategy='uniform',
+            entry_min=polar_min if polar_min is not None else 0,
+            entry_max=polar_max if polar_max is not None else 1)
+        polar = (polar - 0.5) * 2 * polar_stddev + polar_mean
+        polar = torch.arccos(1 - 2 * polar)
+    else:
+        polar = sample_tensor(size=batch_size,
+                              strategy=polar_strategy,
+                              entry_fix=polar_fix,
+                              entry_min=polar_min,
+                              entry_max=polar_max,
+                              entry_mean=polar_mean,
+                              entry_stddev=polar_stddev)
+        azimuthal = sample_tensor(size=batch_size,
+                                  strategy=azimuthal_strategy,
+                                  entry_fix=azimuthal_fix,
+                                  entry_min=azimuthal_min,
+                                  entry_max=azimuthal_max,
+                                  entry_mean=azimuthal_mean,
+                                  entry_stddev=azimuthal_stddev)
+
+    # Sample camera orientation.
+    pitch = sample_tensor(size=batch_size,
+                          strategy=pitch_strategy,
+                          entry_fix=pitch_fix,
+                          entry_min=pitch_min,
+                          entry_max=pitch_max,
+                          entry_mean=pitch_mean,
+                          entry_stddev=pitch_stddev)
+    yaw = sample_tensor(size=batch_size,
+                        strategy=yaw_strategy,
+                        entry_fix=yaw_fix,
+                        entry_min=yaw_min,
+                        entry_max=yaw_max,
+                        entry_mean=yaw_mean,
+                        entry_stddev=yaw_stddev)
+    roll = sample_tensor(size=batch_size,
+                         strategy=roll_strategy,
+                         entry_fix=roll_fix,
+                         entry_min=roll_min,
+                         entry_max=roll_max,
+                         entry_mean=roll_mean,
+                         entry_stddev=roll_stddev)
+
+    # Get running device.
+    device = torch.cuda.current_device() if torch.cuda.is_available() else 'cpu'
+
+    # Get camera position.
+    N = batch_size
+    camera_pos = torch.zeros((N, 3), device=device)
+    if y_axis_up:
+        camera_pos[:, 0] = radius * torch.sin(polar) * torch.cos(azimuthal)
+        camera_pos[:, 1] = radius * torch.cos(polar)
+        camera_pos[:, 2] = radius * torch.sin(polar) * torch.sin(azimuthal)
+    else:
+        camera_pos[:, 0] = radius * torch.sin(polar) * torch.cos(azimuthal)
+        camera_pos[:, 1] = radius * torch.sin(polar) * torch.sin(azimuthal)
+        camera_pos[:, 2] = radius * torch.cos(polar)
+
+    # Get transformation matrix with the following steps.
+    # 1. Use pitch, yaw, and roll to get the rotation matrix within the camera
+    #    coordinate system.
+    # 2. Get the forward axis, which points from the camera position to the
+    #    origin of the world coordinate system.
+    # 3. Get a "pseudo" up axis, which is [0, 1, 0].
+    # 4. Get the left axis by crossing the "pseudo" up axis with the forward
+    #    axis.
+    # 5. Get the "actual" up axis by crossing the forward axis with the left
+    #    axis.
+    # 6. Get the camera-to-world rotation matrix with the aforementioned
+    #    forward axis, left axis, and "actual" up axis.
+    # 7. Get the camera-to-world transformation matrix.
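+    # Putting the steps together, the matrix built below is (applied
+    # right-to-left):
+    #     cam2world = T(camera_pos) @ R_lookat @ R_roll @ R_yaw @ R_pitch
+    # which matches the final composition at the end of this function.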
+ pitch_matrix = torch.eye(4, device=device).unsqueeze(0).repeat(N, 1, 1) + pitch_matrix[:, 1, 1] = torch.cos(pitch) + pitch_matrix[:, 2, 2] = torch.cos(pitch) + pitch_matrix[:, 1, 2] = -torch.sin(pitch) + pitch_matrix[:, 2, 1] = torch.sin(pitch) # [N, 4, 4] + yaw_matrix = torch.eye(4, device=device).unsqueeze(0).repeat(N, 1, 1) + yaw_matrix[:, 0, 0] = torch.cos(yaw) + yaw_matrix[:, 2, 2] = torch.cos(yaw) + yaw_matrix[:, 2, 0] = -torch.sin(yaw) + yaw_matrix[:, 0, 2] = torch.sin(yaw) # [N, 4, 4] + roll_matrix = torch.eye(4, device=device).unsqueeze(0).repeat(N, 1, 1) + roll_matrix[:, 0, 0] = torch.cos(roll) + roll_matrix[:, 1, 1] = torch.cos(roll) + roll_matrix[:, 0, 1] = -torch.sin(roll) + roll_matrix[:, 1, 0] = torch.sin(roll) # [N, 4, 4] + + + forward_axis = F.normalize(camera_pos * -1, dim=-1) # [N, 3] + if y_axis_up: + pseudo_up_axis = torch.as_tensor([0.0, 1.0, 0.0], device=device) # [3] + else: + pseudo_up_axis = torch.as_tensor([0.0, 0.0, 1.0], device=device) # [3] + pseudo_up_axis = pseudo_up_axis.reshape(1, 3).repeat(N, 1) # [N, 3] + left_axis = torch.cross(pseudo_up_axis, forward_axis, dim=-1) # [N, 3] + left_axis = F.normalize(left_axis, dim=-1) # [N, 3] + up_axis = torch.cross(forward_axis, left_axis, dim=-1) # [N, 3] + up_axis = F.normalize(up_axis, dim=-1) # [N, 3] + + rotation_matrix = torch.eye(4, device=device).unsqueeze(0).repeat(N, 1, 1) + rotation_matrix[:, :3, 0] = -left_axis + rotation_matrix[:, :3, 1] = up_axis + rotation_matrix[:, :3, 2] = -forward_axis # [N, 4, 4] + + translation_matrix = torch.eye(4, device=device) + translation_matrix = translation_matrix.unsqueeze(0).repeat(N, 1, 1) + translation_matrix[:, :3, 3] = camera_pos # [N, 4, 4] + + cam2world_matrix = (translation_matrix @ rotation_matrix @ + roll_matrix @ yaw_matrix @ pitch_matrix) # [N, 4, 4] + + return { + 'radius': radius, + 'polar': polar, + 'azimuthal': azimuthal, + 'pitch': pitch, + 'yaw': yaw, + 'roll': roll, + 'camera_pos': camera_pos, + 'cam2world_matrix': cam2world_matrix + } + + +def get_point_coord(rays_camera, radii, cam2world_matrix): + """Gets pre-point coordinate in the world coordinate system. + + `N`, `H`, `W`, `K` denote batch size, image height, image width, number of + points per ray, respectively. + + Args: + rays_camera: Per-pixel ray direction, with shape [N, H, W, 3], in the + camera coordinate system. + radii: Per-point radial distance on each ray, with shape [N, H, W, K]. + cam2world_matrix: Transformation matrix that transforms the camera + coordinate system to the world coordinate system, with shape + [N, 4, 4]. + + Returns: + A dictionary, containing + - `rays_world`: ray directions in the world coordinate system, + with shape [N, H, W, 3]. + - `ray_origins_world`: ray origins in the world coordinate system, + with shape [N, H, W, 3]. + - `points_camera`: per-point coordinate in the camera coordinate + system, with shape [N, H, W, K, 3]. + - `points_world`: per-point coordinate in the world coordinate + system, with shape [N, H, W, K, 3]. + + Raises: + ValueError: If any input has invalid shape. + """ + # Check inputs. 
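+    # (Expected shapes: rays_camera [N, H, W, 3], radii [N, H, W, K],
+    #  cam2world_matrix [N, 4, 4]; see the docstring above.)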
+    if rays_camera.ndim != 4 or rays_camera.shape[3] != 3:
+        raise ValueError(f'The input rays should have shape '
+                         f'[batch_size, height, width, 3], '
+                         f'but `{rays_camera.shape}` was received!')
+    N, H, W, _ = rays_camera.shape
+    if radii.ndim != 4 or radii.shape[:3] != (N, H, W):
+        raise ValueError(f'The input radii should have shape '
+                         f'[batch_size, height, width, num_points], where '
+                         f'batch_size, height, and width align with those of '
+                         f'the rays, but `{radii.shape}` was received!')
+    K = radii.shape[3]
+    if cam2world_matrix.shape != (N, 4, 4):
+        raise ValueError(f'The input cam2world_matrix should have shape '
+                         f'[batch_size, 4, 4], where batch_size aligns with '
+                         f'that of the rays and radii, '
+                         f'but `{cam2world_matrix.shape}` was received!')
+
+    # Get running device.
+    device = torch.cuda.current_device() if torch.cuda.is_available() else 'cpu'
+
+    # Transform rays.
+    rays_world = (cam2world_matrix[:, :3, :3] @
+                  rays_camera.reshape(N, -1, 3).permute(0, 2, 1))
+    rays_world = rays_world.permute(0, 2, 1).reshape(N, H, W, 3)
+
+    # Transform ray origins.
+    ray_origins_homo = torch.zeros((N, H * W, 4), device=device)
+    ray_origins_homo[..., 3] = 1
+    ray_origins_world = torch.bmm(cam2world_matrix,
+                                  ray_origins_homo.permute(0, 2, 1))
+    ray_origins_world = ray_origins_world.permute(0, 2, 1)[..., :3]
+    ray_origins_world = ray_origins_world.reshape(N, H, W, 3)
+
+    # Transform points.
+    points_camera = (rays_camera.unsqueeze(3) *
+                     radii.unsqueeze(4))  # [N, H, W, K, 3]
+    points_camera_homo = torch.cat(
+        [points_camera, torch.ones((N, H, W, K, 1), device=device)],
+        dim=-1)  # [N, H, W, K, 4]
+    points_world_homo = (cam2world_matrix @
+                         points_camera_homo.reshape(N, -1, 4).permute(0, 2, 1))
+    points_world = points_world_homo.permute(0, 2, 1)[:, :, :3]
+    points_world = points_world.reshape(N, H, W, K, 3)
+
+    return {
+        'rays_world': rays_world,
+        'ray_origins_world': ray_origins_world,
+        'points_camera': points_camera,
+        'points_world': points_world,
+    }
diff --git a/source/rendering/sat2density_transform_eg3d.py b/source/rendering/sat2density_transform_eg3d.py
new file mode 100644
index 0000000000000000000000000000000000000000..956d9b7cd9353539f35017d13a807abf0bee5fd1
--- /dev/null
+++ b/source/rendering/sat2density_transform_eg3d.py
@@ -0,0 +1,241 @@
+import numpy as np
+import torch,math
+from PIL import Image
+import torchvision
+from easydict import EasyDict as edict
+
+import torch.nn.functional as F
+import torch.nn as nn
+import random
+from einops import repeat, rearrange
+
+from source.rendering.point_sampler import perturb_points_per_ray
+from source.rendering.aabb import intersect_aabb_end
+from source.rendering.transform_perspective import compose_rotmat
+
+
+def get_normal_coord(W, H, device='cpu'):
+    '''
+    Standard equirectangular panorama coordinate normalization.
+    W: panorama width
+    H: panorama height
+    device: target device, usually `cpu` or `cuda`
+    Returns:
+        normalized_coords: tensor with shape (H, W, 3)
+    '''
+    # Create linear coordinates from 0 to W-1 and 0 to H-1.
+    y = torch.linspace(0, W - 1, W, device=device)
+    x = torch.linspace(0, H - 1, H, device=device)
+
+    # Build the mesh grid.
+    Y, X = torch.meshgrid(y, x, indexing='ij')  # Y: (W, H), X: (W, H)
+
+    # Convert the grid to longitude and latitude.
+    phi = -(Y / (W - 1) - 0.5) * 2 * math.pi + (math.pi / 2)  # longitude
+    theta = -(0.5 - X / (H - 1)) * math.pi  # latitude in [-pi/2, pi/2]
+
+    # Compute normalized 3D coordinates.
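+    # Per-pixel unit direction on the sphere (y points up):
+    #     x = cos(theta) * cos(phi), y = sin(theta), z = cos(theta) * sin(phi)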
+    cos_theta = torch.cos(theta)
+    sin_theta = torch.sin(theta)
+    cos_phi = torch.cos(phi)
+    sin_phi = torch.sin(phi)
+
+    normalized_coords = torch.stack([
+        cos_theta * cos_phi,  # x axis
+        sin_theta,            # y axis
+        cos_theta * sin_phi   # z axis
+    ], dim=2)  # Shape: (W, H, 3)
+    # Reshape to (H, W, 3).
+    normalized_coords = normalized_coords.permute(1, 0, 2)
+
+    return normalized_coords
+
+
+def get_original_coord(W, H, full=True, c2w=None):
+    '''
+    Per-pixel ray directions of a panorama in the world coordinate system.
+    W: width of the panorama
+    H: height of the panorama
+    full: if the dataset is CVACT, full=True returns the original CVACT
+        coordinates; full=False is only used for the CVUSA dataset.
+    c2w: optional camera-to-world matrix whose rotation is applied to the
+        ray directions. If None, an identity orientation is used.
+    '''
+    normalized_coords = get_normal_coord(W, H)
+
+    if c2w is None:
+        RollPitchYaw = [0, 0, 0]
+        R_c2w = compose_rotmat(RollPitchYaw[0], RollPitchYaw[1], RollPitchYaw[2])
+    else:
+        # Assumed handling for a provided pose: use its rotation block.
+        R_c2w = c2w[:3, :3]
+    # To torch, then to the device of normalized_coords.
+    if isinstance(R_c2w, np.ndarray):
+        R_c2w = torch.from_numpy(R_c2w).to(normalized_coords.device).float()
+    ray_directions = torch.einsum('ij,hwj->hwi', R_c2w, normalized_coords)  # [H, W, 3]
+
+    # Normalize ray directions with torch ops.
+    ray_directions = ray_directions / torch.norm(ray_directions, dim=-1, keepdim=True)
+    return ray_directions
+
+
+class Point_sampler_pano(torch.nn.Module):
+    # Designed for street-view panorama images.
+    def __init__(self,
+                 pano_direction,
+                 sample_total_length=None,
+                 num_points=300,
+                 perturbation_strategy='uniform',
+                 aabb_strict=False,
+                 data_type=None,
+                 ):
+        super().__init__()
+        # Honor the constructor argument; fall back to the scene bounding-box
+        # diagonal (the previous hard-coded value) when it is None.
+        if sample_total_length is None:
+            sample_total_length = np.sqrt(1.5**2 + 1.5**2 + 1.9**2)
+        self.sample_total_length = sample_total_length
+
+        self.pano_direction = pano_direction
+        self.num_points = num_points
+        if not aabb_strict:
+            self.sample_len = ((torch.arange(self.num_points) + 1) * (self.sample_total_length / self.num_points)).cuda().float()
+
+        self.voxel_low = -1
+        self.voxel_max = 1
+
+        self.perturbation_strategy = perturbation_strategy
+        self.aabb_strict = aabb_strict
+
+    @torch.no_grad()
+    def forward(self,
+                batch_size,
+                position=None,
+                ):
+        origin_opensfm = position[:, None, None, :].cuda()  # (w, -h, z)
+        pano_direction = self.pano_direction[..., None].cuda()  # b h w c, in the OpenSfM coordinate system
+        output = edict()
+
+        H, W = pano_direction.shape[1], pano_direction.shape[2]
+
+        rays_world = repeat(pano_direction, '1 h w c 1 -> b h w c', b=batch_size)
+        ray_origins = repeat(origin_opensfm, 'b 1 1 c -> b h w c', h=H, w=W)
+
+        if self.aabb_strict:
+            sample_total_length = intersect_aabb_end(ray_origins, rays_world, min=0, max=self.sample_total_length)
+            sample_total_length = rearrange(sample_total_length, '(b h w) -> b h w 1', b=batch_size, h=H, w=W)
+            output.radii_raw = (torch.arange(self.num_points) + 1)[None, None, None, :].to(sample_total_length.device) * (sample_total_length / self.num_points)
+        else:
+            depth = self.sample_len[None, None, None, None, :]
+            output.radii_raw = repeat(depth, '1 1 1 1 k -> b h w k', b=batch_size, h=H, w=W)  # (1, h, w, 1, k) -> (batch_size, h, w, k)
+        output.radii = perturb_points_per_ray(output.radii_raw, strategy=self.perturbation_strategy)
+        sample_point = ray_origins.unsqueeze(-1) + rays_world.unsqueeze(-1) * output.radii.unsqueeze(-2)
+
+        output.points_world = rearrange(sample_point, 'b h w c k -> b h w k c').clone()
+        output.ray_origins = ray_origins.clone()
+        output.ray_origins[...,1] = -output.ray_origins[...,1]
+        output.rays_world = rays_world.clone()
+        output.rays_world[...,1] =
-output.rays_world[...,1] + output.points_world[...,1] = -output.points_world[...,1] + return output + + + + + +def get_sat_ori(resolution,position_scale_factor=1): + y_range = (torch.arange(resolution,dtype=torch.float32,)+0.5)/(0.5*resolution)-1 + x_range = (torch.arange(resolution,dtype=torch.float32,)+0.5)/(0.5*resolution)-1 + Y,X = torch.meshgrid(y_range,x_range) + Y = Y*position_scale_factor + X = X*position_scale_factor + Z = torch.ones_like(Y) # z=1 means the highest position in coordinate, in dimension 1 + xy_grid = torch.stack([X,Z,Y],dim=-1)[None,:,:] + return xy_grid + + +class Point_sampler_ortho(torch.nn.Module): + ''' + point sampler designed for ortho image, + + + ''' + def __init__(self, + num_points, + resolution=256, + perturbation_strategy = 'uniform', + position_scale_factor = 1, + render_size = 128, + ): + super().__init__() + self.perturbation_strategy = perturbation_strategy + # not used any more + self.resolution = resolution + self.sat_ori = get_sat_ori(self.resolution,position_scale_factor)[...,None] + self.sat_dir = torch.tensor([0,-1,0])[None,None,None,:,None] + self.sample_total_length = 2 + self.num_points = num_points + self.sample_len = ((torch.arange(self.num_points)+1)*(self.sample_total_length/self.num_points)) + self.render_size = render_size + + @torch.no_grad() + def forward(self, + batch_size, + random_crop=True, + crop_type=None, + ): + depth = self.sample_len[None,None,None,None,:].cuda().float() + sat_dir = self.sat_dir.cuda() + # sample_point = self.sat_ori + self.sat_dir * depth + output = edict() + if random_crop: + if crop_type == 'crop': + assert self.render_size < self.resolution + start_h = random.randint(0,self.resolution-self.render_size-1) + start_w = random.randint(0,self.resolution-self.render_size-1) + output.idx = [start_h,start_w] + sat_ori = self.sat_ori[:,start_h:start_h+self.render_size,start_w:start_w+self.render_size,:] + elif crop_type == 'resize': + sat_ori = rearrange(self.sat_ori,'b h w c 1 -> b c h w') + sat_ori = F.interpolate(sat_ori,scale_factor=0.5,mode='bilinear') + sat_ori = rearrange(sat_ori,'b c h w -> b h w c 1') + else: + raise NotImplementedError + else: + sat_ori = self.sat_ori + self.render_size = self.resolution + assert self.render_size == self.resolution + sat_ori = sat_ori.cuda() + + # sat_ori = self.position_scale_factor * sat_ori + # output.points_world = repeat(grid, '1 k h w c -> b h w k c', b=batch_size) + output.rays_world = repeat(sat_dir, '1 1 1 c 1 -> b h w c', b=batch_size, h = self.render_size, w = self.render_size )[...,[0,2,1]] # (1,h,w,3,1) -> (batch_size, h, w, 3) + output.radii_raw = repeat(depth, '1 1 1 1 k -> b h w k', b=batch_size, h = self.render_size, w = self.render_size ) # (1,h,w,1,k) -> (batch_size, h, w, k, 1) + output.ray_origins = repeat(sat_ori, '1 h w c 1 -> b h w c',b=batch_size)[...,[0,2,1]] + output.radii = perturb_points_per_ray(output.radii_raw,strategy=self.perturbation_strategy) + + sample_point = sat_ori + sat_dir * output.radii.unsqueeze(-2) + + grid = sample_point.permute(0,4,1,2,3)[...,[0,2,1]] # has a change back, from height in the second dimension to height in the last dimension + # grid[...,2] = ((grid[...,2]-self.voxel_low)/(self.voxel_max-self.voxel_low))*2-1 + # grid = grid.float() + output.points_world = rearrange(grid, 'b k h w c -> b h w k c') + return output + + + + + +# class RGB_Reprerenter(torch.nn.Module): +# def __init__(self, +# ): +# super().__init__() + +# def forward(self, +# points, +# image, +# ): +# point_h_w = points[...,0:2].unsqueeze(2) # 
b, N, 1, 2
+#         rgb_feature = F.grid_sample(image,point_h_w).squeeze(-1).permute(0,2,1) # b, C, N, 1 -> b, N, C
+#         return rgb_feature
+
+
diff --git a/source/rendering/transform_perspective.py b/source/rendering/transform_perspective.py
new file mode 100644
index 0000000000000000000000000000000000000000..20b2b6feec0160e0a25c1bb6edc1e63149b29a3b
--- /dev/null
+++ b/source/rendering/transform_perspective.py
@@ -0,0 +1,149 @@
+import numpy as np
+from scipy.spatial.transform import Rotation
+import torch
+from einops import repeat, rearrange
+from easydict import EasyDict as edict
+import torch.nn.functional as F
+
+from source.rendering.aabb import intersect_aabb_end
+from source.rendering.point_sampler import perturb_points_per_ray
+
+def decompose_rotmat(R_c2w):
+    R_cv2xyz = Rotation.from_euler("X", -90, degrees=True)
+    rot_w2c = R_cv2xyz * Rotation.from_matrix(R_c2w).inv()
+    roll, pitch, yaw = rot_w2c.as_euler("YXZ", degrees=True)
+    return roll, pitch, yaw
+
+def normalize_angles(angles):
+    """Normalize angles to be within the range [-180, 180] degrees."""
+    return (np.array(angles) + 180) % 360 - 180
+
+def compose_rotmat(roll, pitch, yaw):
+    R_cv2xyz = Rotation.from_euler("X", -90, degrees=True)
+    rot_c2w = Rotation.from_euler("YXZ", [roll, pitch, yaw], degrees=True)
+
+    rot_w2c = R_cv2xyz.inv() * rot_c2w
+
+    return rot_w2c.inv().as_matrix()
+
+
+def fov_size2intrinsics(fov, img_size):
+    """Converts field of view size to camera intrinsics."""
+    if isinstance(fov, (int, float)):
+        fov = [fov, fov]
+    fov_x = np.deg2rad(fov[0])
+    fov_y = np.deg2rad(fov[1])
+    fx = (img_size[0] / 2) / np.tan(fov_x / 2)
+    fy = (img_size[1] / 2) / np.tan(fov_y / 2)
+    return np.array([[fx, 0, img_size[0] / 2],
+                     [0, fy, img_size[1] / 2],
+                     [0, 0, 1]])
+
+def from_Euler_and_position_to_c2w(roll_pitch_yaw, position):
+    roll, pitch, yaw = roll_pitch_yaw
+    rot_c2w = compose_rotmat(roll, pitch, yaw)
+    T_c2w = np.eye(4)
+    T_c2w[:3, :3] = rot_c2w
+    if isinstance(position, torch.Tensor):
+        position = position.cpu().numpy()
+    T_c2w[:3, 3] = position
+    return T_c2w
+
+class PointSamplerPerspective(torch.nn.Module):
+    def __init__(self, num_points, aabb_strict=True, perturbation_strategy='uniform', render_size=[128, 128]):
+        """
+        render_size: [H,W]
+        num_points: number of points to sample along each ray
+        aabb_strict: whether to use strict AABB for sampling
+        perturbation_strategy: strategy for perturbing points along the ray
+        """
+        super().__init__()
+        # NOTE: no trailing comma here; a trailing comma would turn the flag
+        # into a (always-truthy) one-element tuple and silently break the
+        # `aabb_strict=False` path below.
+        self.aabb_strict = aabb_strict
+        self.sample_total_length = np.sqrt(1.5**2 + 1.5**2 + 1.9**2)
+        self.num_points = num_points
+        self.perturbation_strategy = perturbation_strategy
+        self.render_size = render_size
+
+    @torch.no_grad()
+    def forward(self, intrinsics, c2w):
+        # c2w: B x 4 x 4
+        # intrinsics: B x 3 x 3
+        # return:
+        #     output = edict()
+        #     output.rays_world: B x H x W x C    # direction of the rays
+        #     output.radii_raw: B x H x W x K
+        #     output.radii: B x H x W x K
+        #     output.ray_origins: B x H x W x C   # origin of the rays
+        #     output.points_world: B x H x W x K x C
+
+        batch_size = c2w.shape[0]
+        t = c2w[:, :3, 3].clone()
+        output = edict()
+        output.ray_origins = repeat(t, 'b c -> b h w c', h=self.render_size[0], w=self.render_size[1]).cuda()
+        output.ray_origins = output.ray_origins.clone()  # w -h z
+        output.rays_world = compute_ray_directions(c2w.cuda(), intrinsics.cuda(), self.render_size[0], self.render_size[1])
+
+        if self.aabb_strict:
+            # from b c to (b h w) c
+            # origin_for_aabb = repeat(output.ray_origins, 'b c -> b h w c', h = H, w = W)
+            # from b h w c to (b
h w) c + # pano_direction_for_aabb = repeat(output.rays_world, 'b h w c -> b h w c', h = H, w = W) + sample_total_length = intersect_aabb_end(output.ray_origins,output.rays_world,min=0,max=self.sample_total_length) + sample_total_length = rearrange(sample_total_length, '(b h w) -> b h w 1', b=batch_size, h = self.render_size[0], w = self.render_size[1] ) + output.radii_raw = (torch.arange(self.num_points)+1)[None,None,None,:].to(sample_total_length.device) * (sample_total_length/self.num_points) + else: + raise NotImplementedError + output.radii = perturb_points_per_ray(output.radii_raw,strategy=self.perturbation_strategy) + sample_point = output.ray_origins.unsqueeze(-1) + output.rays_world.unsqueeze(-1) * output.radii.unsqueeze(-2) + output.points_world = rearrange(sample_point, 'b h w c k -> b h w k c') + # process_from w -h z to w h z + output.ray_origins[...,1] = -output.ray_origins[...,1] + output.rays_world[...,1] = -output.rays_world[...,1] + output.points_world[...,1] = -output.points_world[...,1] + return output + + + # return output + + + +def generate_pixel_coordinates(H, W): + """ + Generate pixel coordinates grid on the image plane. + + Parameters: + - H: Image height + - W: Image width + + Returns: + - pixel_coords: Pixel coordinates grid with shape [H, W, 3] + """ + y, x = torch.meshgrid(torch.arange(H), torch.arange(W), indexing='ij') + # to current device + pixel_coords = torch.stack([x, y, torch.ones_like(x)], dim=-1).float().to(torch.cuda.current_device() if torch.cuda.is_available() else 'cpu') + return pixel_coords + +def compute_ray_directions(camera2world, intrinsics, H, W): + """ + Compute ray directions. + + Parameters: + - camera2world: Camera-to-world transformation matrix with shape [B, 4, 4] + - intrinsics: Intrinsic matrix with shape [B, n, n] + + Returns: + - ray_directions: Ray directions with shape [B, H, W, 3] + """ + B = camera2world.shape[0] + pixel_coords = generate_pixel_coordinates(H, W) # [H, W, 3] + pixel_coords = pixel_coords.unsqueeze(0).expand(B, -1, -1, -1) # [B, H, W, 3] + + inv_intrinsics = torch.inverse(intrinsics) # [B, n, n] + normalized_coords = torch.einsum('bij,bhwj->bhwi', inv_intrinsics, pixel_coords) # [B, H, W, 3] + + ray_directions = torch.einsum('bij,bhwj->bhwi', camera2world[:, :3, :3], normalized_coords) # [B, H, W, 3] + + ray_directions = ray_directions / torch.norm(ray_directions, dim=-1, keepdim=True) + + return ray_directions diff --git a/source/rendering/utils.py b/source/rendering/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..e3b5883ad4b1caba513c2e0e818007438247b83b --- /dev/null +++ b/source/rendering/utils.py @@ -0,0 +1,586 @@ +# python3.8 +"""Utility functions used for rendering module.""" + +import math +import torch +import torch.nn.functional as F +import numpy as np + +EPS = 1e-6 + + +def sample_importance(radial_dists, + weights, + num_importance, + smooth_weights=False): + """Implements importance sampling, which is the crucial step in hierarchical + sampling of NeRF. Hierarchical volume sampling mainly includes the following + steps: + + 1. Sample a set of `Nc` points using stratified sampling. + 2. Evaluate the 'coarse' network at locations of these points as described + in Eq. (2) & (3) in the paper. + 3. Normalize the output weights to get a piecewise-constant PDF (probability + density function) along the ray. + 4. Sample a second set of `Nf` points from this distribution using inverse + transform sampling. + + And importance sampling refers to step 4 specifically. 
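+
+    Concretely, step 4 draws u ~ U[0, 1), locates the first bin whose
+    cumulative (normalized) weight exceeds u, and linearly interpolates a
+    radial distance inside that bin, so rays receive more fine samples where
+    the coarse pass found mass (see `sample_pdf()` below).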
+ + Code is borrowed from: + + https://github.com/NVlabs/eg3d/blob/main/eg3d/training/volumetric_rendering/renderer.py + + Args: + radial_dists: Radial distances, with shape [N, R, K, 1] + weights: Per-point weight for integral, with shape [N, R, K, 1]. + num_importance: Number of points for importance sampling. + smooth_weights: Whether to smooth weights. Defaults to `False`. + + Returns: + importance_radial_dists: Radial distances of importance sampled points + along rays. + """ + with torch.no_grad(): + batch_size, num_rays, samples_per_ray, _ = radial_dists.shape + radial_dists = radial_dists.reshape(batch_size * num_rays, + samples_per_ray) + weights = weights.reshape(batch_size * num_rays, -1) + 1e-5 + + # Smooth weights. + if smooth_weights: + weights = F.max_pool1d(weights.unsqueeze(1).float(), + 2, 1, padding=1) + weights = F.avg_pool1d(weights, 2, 1).squeeze() + weights = weights + 0.01 + + radial_dists_mid = 0.5 * (radial_dists[:, :-1] + radial_dists[:, 1:]) + importance_radial_dists = sample_pdf(radial_dists_mid, weights[:, 1:-1], + num_importance) + importance_radial_dists = importance_radial_dists.detach().reshape( + batch_size, num_rays, num_importance, 1) + + return importance_radial_dists + + +def sample_pdf(bins, weights, num_importance, det=False, eps=1e-5): + """Sample `num_importance` samples from `bins` with distribution defined + by `weights`. Borrowed from: + + https://github.com/kwea123/nerf_pl/blob/master/models/rendering.py + + Args: + bins: Bins distributed along rays, with shape (N * R, K - 1). + weights: Per-point weight for integral, with shape [N * R, K]. + num_importance: The number of samples to draw from the distribution. + det: Deterministic or not. Defaults to `False`. + eps: A small number to prevent division by zero. Defaults to `1e-5`. + + Returns: + samples: The sampled samples. + """ + n_rays, n_samples_ = weights.shape + weights = weights + eps + # Prevent division by zero (don't do inplace op!). + pdf = weights / torch.sum(weights, -1, + keepdim=True) # (n_rays, n_samples_) + cdf = torch.cumsum(pdf, -1) # (n_rays, N_samples), + # Cumulative distribution function. + cdf = torch.cat([torch.zeros_like(cdf[:, :1]), cdf], + -1) # (n_rays, n_samples_+1) + + if det: + u = torch.linspace(0, 1, num_importance, device=bins.device) + u = u.expand(n_rays, num_importance) + else: + u = torch.rand(n_rays, num_importance, device=bins.device) + u = u.contiguous() + + indices = torch.searchsorted(cdf, u) + below = torch.clamp_min(indices - 1, 0) + above = torch.clamp_max(indices, n_samples_) + + indices_sampled = torch.stack([below, above], -1).view(n_rays, + 2 * num_importance) + cdf_g = torch.gather(cdf, 1, indices_sampled) + cdf_g = cdf_g.view(n_rays, num_importance, 2) + bins_g = torch.gather(bins, 1, indices_sampled).view(n_rays, + num_importance, 2) + + # `denom` equals 0 means a bin has weight 0, in which case it will not be + # sampled anyway, therefore any value for it is fine (set to 1 here). + denom = cdf_g[..., 1] - cdf_g[..., 0] + denom[denom < eps] = 1 + + samples = (bins_g[..., 0] + (u - cdf_g[..., 0]) / + denom * (bins_g[..., 1] - bins_g[..., 0])) + + return samples + + +def unify_attributes(radial_dists1, + colors1, + densities1, + radial_dists2, + colors2, + densities2, + points1=None, + points2=None): + """Unify attributes of point samples according to their radial distances. + + Args: + radial_dists1: Radial distances of the first pass, with shape + [N, R, K1, 1]. + colors1: Colors or features of the first pass, with shape [N, R, K1, C]. 
+ densities1: Densities of the first pass, with shape [N, R, K1, 1]. + radial_dists2: Radial distances of the second pass, with shape + [N, R, K2, 1]. + colors2: Colors or features of the second pass, with shape + [N, R, K2, C]. + densities2: Densities of the second pass, with shape [N, R, K2, 1]. + points1 (optional): Point coordinates of the first pass, + with shape [N, R, K1, 3]. + points2 (optional): Point coordinates of the second pass, + with shape [N, R, K2, 3]. + + Returns: + all_radial_dists: Unified radial distances, with shape [N, R, K1+K2, 1]. + all_colors: Unified colors or features, with shape [N, R, K1+K2, C]. + all_densities: Unified densities, with shape [N, R, K1+K2, 1]. + """ + all_radial_dists = torch.cat([radial_dists1, radial_dists2], dim=-2) + all_colors = torch.cat([colors1, colors2], dim=-2) + all_densities = torch.cat([densities1, densities2], dim=-2) + + _, indices = torch.sort(all_radial_dists, dim=-2) + all_radial_dists = torch.gather(all_radial_dists, -2, indices) + all_colors = torch.gather( + all_colors, -2, indices.expand(-1, -1, -1, all_colors.shape[-1])) + all_densities = torch.gather(all_densities, -2, + indices.expand(-1, -1, -1, 1)) + + if points1 is not None and points2 is not None: + all_points = torch.cat([points1, points2], dim=-2) + all_points = torch.gather( + all_points, -2, indices.expand(-1, -1, -1, all_points.shape[-1])) + return all_radial_dists, all_colors, all_densities, all_points + + return all_radial_dists, all_colors, all_densities + + +def depth2pts_outside(ray_o, ray_d, depth): + """Compute point coordinates in the inverted sphere parameterization. + + This function is borrowed from the official code of NeRF++: + + https://github.com/Kai-46/nerfplusplus + + Args: + ray_o (torch.Tensor): Ray origins, with shape [N, R, K, 3]. + ray_d (torch.Tensor): Ray directions, with shape [N, R, K, 3]. + depth (torch.Tensor): Inverse of distance to sphere origin, + with shape [N, R, K]. + + Returns: + pts (torch.Tensor): Sampled points with inversed sphere parametrization, + denoted as (x', y', z', 1/r), with shape [N, R, K, 4]. + depth_real (torch.Tensor): Depth in Euclidean space. + """ + + # Note: d1 becomes negative if this mid point is behind camera. + d1 = -torch.sum(ray_d * ray_o, dim=-1) / torch.sum(ray_d * ray_d, dim=-1) + p_mid = ray_o + d1.unsqueeze(-1) * ray_d + p_mid_norm = torch.norm(p_mid, dim=-1) + ray_d_cos = 1. / torch.norm(ray_d, dim=-1) + d2 = torch.sqrt(1. - p_mid_norm * p_mid_norm) * ray_d_cos + p_sphere = ray_o + (d1 + d2).unsqueeze(-1) * ray_d + + rot_axis = torch.cross(ray_o, p_sphere, dim=-1) + rot_axis = rot_axis / torch.norm(rot_axis, dim=-1, keepdim=True) + phi = torch.asin(p_mid_norm) + theta = torch.asin(p_mid_norm * depth) # depth is inside [0, 1] + rot_angle = (phi - theta).unsqueeze(-1) # [..., 1] + + # Rotate p_sphere using Rodrigues formula: + # https://en.wikipedia.org/wiki/Rodrigues%27_rotation_formula + p_sphere_new = ( + p_sphere * torch.cos(rot_angle) + + torch.cross(rot_axis, p_sphere, dim=-1) * torch.sin(rot_angle) + + rot_axis * torch.sum(rot_axis * p_sphere, dim=-1, keepdim=True) * + (1. - torch.cos(rot_angle))) + p_sphere_new = p_sphere_new / torch.norm( + p_sphere_new, dim=-1, keepdim=True) + pts = torch.cat((p_sphere_new, depth.unsqueeze(-1)), dim=-1) + + # Calculate conventional depth. + depth_real = 1. 
/ (depth + EPS) * torch.cos(theta) * ray_d_cos + d1 + + return pts, depth_real + + +class PositionEncoder(torch.nn.Module): + """Implements the class for positional encoding.""" + + def __init__(self, + input_dim, + max_freq_log2, + num_freqs, + log_sampling=True, + factor=1.0, + include_input=True, + periodic_fns=(torch.sin, torch.cos)): + """Initializes with basic settings. + + Args: + input_dim: Dimension of input to be embedded. + max_freq_log2: `log2` of max freq; min freq is 1 by default. + num_freqs: Number of frequency bands. + log_sampling: If True, frequency bands are linerly sampled in + log-space. + factor: Factor of the frequency bands. + include_input: If True, raw input is included in the embedding. + Defaults to True. + periodic_fns: Periodic functions used to embed input. + Defaults to (torch.sin, torch.cos). + """ + super().__init__() + + self.input_dim = input_dim + self.include_input = include_input + self.periodic_fns = periodic_fns + + self.out_dim = 0 + if self.include_input: + self.out_dim += self.input_dim + + self.out_dim += self.input_dim * num_freqs * len(self.periodic_fns) + + if log_sampling: + self.freq_bands = 2.**torch.linspace(0., max_freq_log2, + num_freqs) * factor + else: + self.freq_bands = torch.linspace(2.**0., 2.**max_freq_log2, + num_freqs) * factor + + self.freq_bands = self.freq_bands.numpy().tolist() + + def forward(self, input): + """Forward function of positional encoding. + + Args: + input: Input tensor with shape [..., input_dim] + + Returns: + output: Output tensor with shape [..., out_dim] + """ + output = [] + if self.include_input: + output.append(input) + + for i in range(len(self.freq_bands)): + freq = self.freq_bands[i] + for p_fn in self.periodic_fns: + output.append(p_fn(input * freq)) + output = torch.cat(output, dim=-1) + + return output + + def get_out_dim(self): + return self.out_dim + + +class GaussianCameraPoseSampler: + """ + Samples pitch and yaw from a Gaussian distribution and returns a camera + pose. Camera is specified as looking at the origin. If horizontal and + vertical stddev (specified in radians) are zero, gives a deterministic + camera pose with yaw=horizontal_mean, pitch=vertical_mean. The coordinate + system is specified with y-up, z-forward, x-left. Horizontal mean is the + azimuthal angle (rotation around y axis) in radians, vertical mean is the + polar angle (angle from the y axis) in radians. A point along the z-axis + has azimuthal_angle=0, polar_angle=pi/2. 
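+    The sampled polar value is remapped through arccos(1 - 2 * v / pi), which
+    makes positions area-uniform on the sphere when v itself is uniform.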
+ + Example: + For a camera looking at the origin with the camera at position [0, 0, 1]: + cam2world = GaussianCameraPoseSampler.sample(math.pi/2, + math.pi/2, + radius=1) + """ + + @staticmethod + def sample(horizontal_mean, + vertical_mean, + horizontal_stddev=0, + vertical_stddev=0, + radius=1, + batch_size=1, + device='cpu'): + h = torch.randn((batch_size, 1), + device=device) * horizontal_stddev + horizontal_mean + v = torch.randn( + (batch_size, 1), device=device) * vertical_stddev + vertical_mean + v = torch.clamp(v, 1e-5, math.pi - 1e-5) + + theta = h + v = v / math.pi + phi = torch.arccos(1 - 2 * v) + + camera_origins = torch.zeros((batch_size, 3), device=device) + + camera_origins[:, 0:1] = radius * torch.sin(phi) * torch.cos(math.pi - + theta) + camera_origins[:, 2:3] = radius * torch.sin(phi) * torch.sin(math.pi - + theta) + camera_origins[:, 1:2] = radius * torch.cos(phi) + + forward_vectors = normalize_vecs(-camera_origins) + return create_cam2world_matrix(forward_vectors, camera_origins) + + +class LookAtPoseSampler: + """ + Same as GaussianCameraPoseSampler, except the + camera is specified as looking at 'lookat_position', a 3-vector. + + Example: + For a camera pose looking at the origin with the camera at position [ + 0, 0, 1]: + cam2world = LookAtPoseSampler.sample( + math.pi/2, math.pi/2, torch.tensor([0, 0, 0]), radius=1) + """ + + @staticmethod + def sample(horizontal_mean, + vertical_mean, + lookat_position, + horizontal_stddev=0, + vertical_stddev=0, + radius=1, + batch_size=1, + device='cpu'): + h = torch.randn((batch_size, 1), + device=device) * horizontal_stddev + horizontal_mean + v = torch.randn( + (batch_size, 1), device=device) * vertical_stddev + vertical_mean + v = torch.clamp(v, 1e-5, math.pi - 1e-5) + + theta = h + v = v / math.pi + phi = torch.arccos(1 - 2 * v) + + camera_origins = torch.zeros((batch_size, 3), device=device) + + camera_origins[:, 0:1] = radius * torch.sin(phi) * torch.cos(math.pi - + theta) + camera_origins[:, 2:3] = radius * torch.sin(phi) * torch.sin(math.pi - + theta) + camera_origins[:, 1:2] = radius * torch.cos(phi) + + # forward_vectors = normalize_vecs(-camera_origins) + forward_vectors = normalize_vecs(lookat_position - camera_origins) + return create_cam2world_matrix(forward_vectors, camera_origins) + + +class UniformCameraPoseSampler: + """ + Same as GaussianCameraPoseSampler, except the pose is sampled from a + uniform distribution with range +-[horizontal/vertical]_stddev. 
+ + Example: + For a batch of random camera poses looking at the origin with yaw sampled + from [-pi/2, +pi/2] radians: + + cam2worlds = UniformCameraPoseSampler.sample(math.pi/2, + math.pi/2, + horizontal_stddev=math.pi/2, + radius=1, + batch_size=16) + """ + + @staticmethod + def sample(horizontal_mean, + vertical_mean, + horizontal_stddev=0, + vertical_stddev=0, + radius=1, + batch_size=1, + device='cpu'): + h = (torch.rand((batch_size, 1), device=device) * 2 - + 1) * horizontal_stddev + horizontal_mean + v = (torch.rand((batch_size, 1), device=device) * 2 - + 1) * vertical_stddev + vertical_mean + v = torch.clamp(v, 1e-5, math.pi - 1e-5) + + theta = h + v = v / math.pi + phi = torch.arccos(1 - 2 * v) + + camera_origins = torch.zeros((batch_size, 3), device=device) + + camera_origins[:, 0:1] = radius * torch.sin(phi) * torch.cos(math.pi - + theta) + camera_origins[:, 2:3] = radius * torch.sin(phi) * torch.sin(math.pi - + theta) + camera_origins[:, 1:2] = radius * torch.cos(phi) + + forward_vectors = normalize_vecs(-camera_origins) + return create_cam2world_matrix(forward_vectors, camera_origins) + + +def create_cam2world_matrix(forward_vector, origin): + """ + Takes in the direction the camera is pointing and the camera origin and + returns a cam2world matrix. Works on batches of forward_vectors, origins. + Assumes y-axis is up and that there is no camera roll. + """ + + forward_vector = normalize_vecs(forward_vector) + up_vector = torch.tensor([0, 1, 0], + dtype=torch.float, + device=origin.device).expand_as(forward_vector) + + right_vector = -normalize_vecs( + torch.cross(up_vector, forward_vector, dim=-1)) + up_vector = normalize_vecs( + torch.cross(forward_vector, right_vector, dim=-1)) + + rotation_matrix = torch.eye(4, device=origin.device).unsqueeze(0).repeat( + forward_vector.shape[0], 1, 1) + rotation_matrix[:, :3, :3] = torch.stack( + (right_vector, up_vector, forward_vector), axis=-1) + + translation_matrix = torch.eye(4, device=origin.device) + translation_matrix = translation_matrix.unsqueeze(0).repeat( + forward_vector.shape[0], 1, 1) + translation_matrix[:, :3, 3] = origin + cam2world = (translation_matrix @ rotation_matrix)[:, :, :] + assert (cam2world.shape[1:] == (4, 4)) + return cam2world + + +def compute_camera_origins(angles, radius): + yaw = angles[:, [0]] # [batch_size, 1] + pitch = angles[:, [1]] # [batch_size, 1] + + assert yaw.ndim == 2, f"Wrong shape: {yaw.shape}, {pitch.shape}" + assert yaw.shape == pitch.shape, f"Wrong shape: {yaw.shape}, {pitch.shape}" + + origins = torch.zeros((yaw.shape[0], 3), device=yaw.device) + origins[:, [0]] = radius * torch.sin(pitch) * torch.cos(yaw) + origins[:, [2]] = radius * torch.sin(pitch) * torch.sin(yaw) + origins[:, [1]] = radius * torch.cos(pitch) + + return origins + + +def compute_cam2world_matrix(camera_angles, radius): + """ + Takes in the direction the camera is pointing and the camera origin and + returns a cam2world matrix. 
+ + Note: `camera_angles` should be provided in the "yaw/pitch/roll" format, + and with shape [batch_size, 3] + """ + camera_origins = compute_camera_origins(camera_angles, + radius) # [batch_size, 3] + forward_vector = normalize_vecs(-camera_origins) # [batch_size, 3] + batch_size = forward_vector.shape[0] + forward_vector = normalize_vecs(forward_vector) + up_vector = torch.tensor( + [0, 1, 0], dtype=torch.float, + device=forward_vector.device).expand_as(forward_vector) + left_vector = normalize_vecs(torch.cross(up_vector, forward_vector, + dim=-1)) + up_vector = normalize_vecs(torch.cross(forward_vector, left_vector, + dim=-1)) + + rotation_matrix = torch.eye( + 4, device=forward_vector.device).unsqueeze(0).repeat(batch_size, 1, 1) + rotation_matrix[:, :3, :3] = torch.stack( + (-left_vector, up_vector, -forward_vector), axis=-1) + + translation_matrix = torch.eye( + 4, device=forward_vector.device).unsqueeze(0).repeat(batch_size, 1, 1) + translation_matrix[:, :3, 3] = camera_origins + + cam2world = translation_matrix @ rotation_matrix + + return cam2world + + +def FOV_to_intrinsics(fov_degrees, device='cpu'): + """ + Creates a 3x3 camera intrinsics matrix from the camera field of view, + specified in degrees. Note the intrinsics are returned as normalized by + image size, rather than in pixel units. Assumes principal point is at image + center. + """ + + focal_length = float(1 / (math.tan(fov_degrees * 3.14159 / 360) * 1.414)) + intrinsics = torch.tensor( + [[focal_length, 0, 0.5], [0, focal_length, 0.5], [0, 0, 1]], + device=device) + return intrinsics + + +def normalize_vecs(vectors, dim=-1): + """Normalize vectors.""" + return vectors / (torch.norm(vectors, dim=dim, keepdim=True)) + + +def dividable(n, k=2): + if k == 2: + for i in range(int(math.sqrt(n)), 0, -1): + if n % i == 0: + break + return i, n // i + elif k == 3: + for i in range(int(float(n) ** (1/3)), 0, -1): + if n % i == 0: + b, c = dividable(n // i, 2) + return i, b, c + else: + raise NotImplementedError + + +def create_voxel(N=256, voxel_corner=[0, 0, 0], voxel_length=2.0,position_scale_factor=1): + """Creates a voxel grid. + + Args: + N (int): Number of points in each side of the generated voxels. + Defaults to 256. + voxel_corner (list): Corner coordinate of the voxel, which represents + (bottom, left, down) of the voxel. Defaults to [0, 0, 0]. + voxel_length (float): Side length of the voxel. Defaults to 2.0. + + Returns: + A dictionary, containing: + - `voxel_grid`: voxel grid, with shape [1, N * N * N, 3]. + - `voxel_origin`: origin of the voxel grid, with shape [3]. + - `voxel_size`: voxel grid size, i.e. the distance between two + adjacent points in the voxel grid. + """ + voxel_origin = np.array(voxel_corner) - voxel_length / 2 + voxel_size = voxel_length / (N - 1) + + overall_index = torch.arange(0, N ** 3, 1, out=torch.LongTensor()) + grid = torch.zeros(N ** 3, 3) + + # Get the x, y, z index of each point in the grid. + grid[:, 2] = overall_index % N + grid[:, 1] = (overall_index.float() / N) % N + grid[:, 0] = ((overall_index.float() / N) / N) % N + + # Get the x, y, z coordinate of each point in the grid. 
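+    # (Each coordinate is index * voxel_size + voxel_origin, so every axis
+    #  spans [voxel_origin, voxel_origin + voxel_length].)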
+    grid[:, 0] = (grid[:, 0] * voxel_size) + voxel_origin[0]
+    grid[:, 1] = (grid[:, 1] * voxel_size) + voxel_origin[1]
+    grid[:, 2] = (grid[:, 2] * voxel_size) + voxel_origin[2]
+    grid[:, 0] = grid[:, 0] * position_scale_factor
+    grid[:, 1] = grid[:, 1] * position_scale_factor
+    voxel = {
+        'voxel_grid': grid.unsqueeze(0),
+        'voxel_origin': voxel_origin,
+        'voxel_size': voxel_size
+    }
+
+    return voxel
\ No newline at end of file
diff --git a/source/sky_histogram.py b/source/sky_histogram.py
new file mode 100644
index 0000000000000000000000000000000000000000..a52d2f4cce2efb32258a9179d49dbfa9b8aaa00d
--- /dev/null
+++ b/source/sky_histogram.py
@@ -0,0 +1,37 @@
+"""Shared utility for computing sky-colour histograms from a panorama image."""
+
+import numpy as np
+import torch
+
+
+def compute_sky_histogram(sky_image_np: np.ndarray,
+                          hist_range: tuple = (-1, 1),
+                          bins: int = 100,
+                          skip_bins: int = 10) -> np.ndarray:
+    """Compute a normalised colour histogram from a masked sky image.
+
+    Parameters
+    ----------
+    sky_image_np : np.ndarray
+        Sky image array of shape ``(C, H, W)`` (single image, no batch dim).
+        Pixel values should lie within *hist_range*.
+    hist_range : tuple
+        ``(min, max)`` passed to ``np.histogram``.
+    bins : int
+        Number of histogram bins.
+    skip_bins : int
+        Number of leading bins to discard (removes near-zero / masked pixels).
+
+    Returns
+    -------
+    np.ndarray
+        Concatenated normalised histogram of shape ``((bins - skip_bins) * C,)``.
+    """
+    channel_histograms = []
+    for channel in sky_image_np:
+        histo = np.histogram(channel.flatten(), bins=bins, range=hist_range)[0]
+        histo = histo[skip_bins:]
+        if histo.sum() != 0:
+            histo = histo / histo.sum()
+        channel_histograms.append(histo)
+    return np.concatenate(channel_histograms)
diff --git a/source/sr_module.py b/source/sr_module.py
new file mode 100644
index 0000000000000000000000000000000000000000..66d112d2769e629d947999b4599e9565e372848e
--- /dev/null
+++ b/source/sr_module.py
@@ -0,0 +1,84 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from source.vae_hacked import ResnetBlock
+
+class SuperresolutionHybrid2X(nn.Module):
+    def __init__(self, in_channels, out_channels, padding_mode='zeros', v2=False):
+        super().__init__()
+        self.out_channels = out_channels
+        self.block0 = SynthesisBlockNoUp(in_channels, 128, padding_mode=padding_mode)
+        self.block1 = SynthesisBlock(128, out_channels, is_last=True, padding_mode=padding_mode, v2=v2)
+
+    def forward(self, x):
+        img = x[:, :3, :, :].clone()
+        if self.out_channels == 4:
+            # Keep the alpha channel 4D ([B, 1, H, W]); an extra unsqueeze
+            # would make it 5D and break the interpolate/cat calls downstream.
+            opa = x[:, -1:, :, :]
+        else:
+            opa = None
+        x = self.block0(x)
+        x = self.block1(x, img, opa)
+        return x
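+
+# Rough usage sketch (shapes and channel counts are illustrative, not taken
+# from the repo): the module doubles the spatial resolution, e.g.
+#     sr = SuperresolutionHybrid2X(in_channels=32, out_channels=3)
+#     rgb = sr(torch.randn(1, 32, 64, 64))  # -> [1, 3, 128, 128]
+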
+class SynthesisBlockNoUp(torch.nn.Module):
+    # conv group
+    def __init__(self, in_channels, out_channels, padding_mode):
+        super().__init__()
+        self.conv0 = ConvBlock(in_channels=in_channels, out_channels=128, padding_mode=padding_mode)
+        self.conv1 = ConvBlock(in_channels=128, out_channels=out_channels, padding_mode=padding_mode)
+        self.skip_conv = nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=1, stride=1, padding=0, padding_mode=padding_mode)
+
+    def forward(self, x):
+        input = self.skip_conv(x)
+        x = self.conv0(x)
+        x = self.conv1(x)
+        return x + input
+
+class SynthesisBlock(torch.nn.Module):
+    # need 2x upsample
+    # need toRGB layer (1x1 conv) to get 3 channels
+    def __init__(self, in_channels, out_channels, padding_mode, is_last=True, v2=False):
+        super().__init__()
+        self.conv0 = ConvBlock(in_channels=in_channels, out_channels=128, padding_mode=padding_mode)
+        self.conv1 = ConvBlock(in_channels=128, out_channels=128, padding_mode=padding_mode)
+        self.is_last = is_last
+        self.v2 = v2
+        self.toRGB = nn.Conv2d(in_channels=128, out_channels=out_channels, kernel_size=1, stride=1, padding=0, padding_mode=padding_mode, bias=not self.v2)
+        self.out_channels = out_channels
+
+    def forward(self, x, img, opa=None):
+        x = F.interpolate(x, scale_factor=2, mode='bilinear', align_corners=False, antialias=True)
+        img = F.interpolate(img, scale_factor=2, mode='bilinear', align_corners=False, antialias=True)
+        if opa is not None:
+            # `nearest` interpolation accepts neither `align_corners` nor
+            # `antialias`; passing them raises an error, so they are omitted.
+            opa = F.interpolate(opa, scale_factor=2, mode='nearest')
+            x = torch.cat([x, opa], dim=1)
+        x = self.conv0(x)
+        x = self.conv1(x)
+        if self.is_last:
+            x = self.toRGB(x)
+        if not self.v2:
+            x = torch.tanh(x)
+        img = img + x[:, :3]
+        # clamp
+        if opa is not None:
+            opa = opa + x[:, 3:]
+            opa = torch.clamp(opa, 0, 1)
+        img = torch.clamp(img, 0, 1)
+        # x = torch.sigmoid(x)
+        if opa is not None:
+            return img, opa
+        return img
+
+class ConvBlock(nn.Module):
+    def __init__(self, in_channels, out_channels, padding_mode, kernel_size=3, stride=1, padding=1, if_act=True):
+        super().__init__()
+        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding, padding_mode=padding_mode)
+        self.norm = nn.InstanceNorm2d(out_channels)
+        self.act = nn.LeakyReLU(0.2) if if_act else nn.Identity()
+
+    def forward(self, x):
+        x = self.conv(x)
+        x = self.norm(x)
+        x = self.act(x)
+        return x
diff --git a/source/training_utils.py b/source/training_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..c31c9b4abc171ae07196bb1eb5351ef9563d6a67
--- /dev/null
+++ b/source/training_utils.py
@@ -0,0 +1,34 @@
+import os
+import shutil
+
+
+def save_checkpoint(accelerator, args, global_step, logger):
+    """Save model checkpoint and enforce checkpoints_total_limit."""
+    save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}")
+    accelerator.save_state(save_path)
+    logger.info(f"Saved state to {save_path}")
+
+    if args.checkpoints_total_limit is not None:
+        checkpoints = os.listdir(args.output_dir)
+        checkpoints = [d for d in checkpoints if d.startswith("checkpoint")]
+        checkpoints = sorted(checkpoints, key=lambda x: int(x.split("-")[1]))
+
+        if len(checkpoints) >= args.checkpoints_total_limit:
+            num_to_remove = len(checkpoints) - args.checkpoints_total_limit + 1
+            removing_checkpoints = checkpoints[:num_to_remove]
+
+            logger.info(
+                f"{len(checkpoints)} checkpoints already exist, removing {len(removing_checkpoints)} checkpoints"
+            )
+            logger.info(f"removing checkpoints: {', '.join(removing_checkpoints)}")
+
+            for removing_checkpoint in removing_checkpoints:
+                checkpoint_path = os.path.join(args.output_dir, removing_checkpoint)
+                if os.path.exists(checkpoint_path):
+                    try:
+                        shutil.rmtree(checkpoint_path)
+                        logger.info(f'Directory "{checkpoint_path}" has been removed.')
+                    except Exception as e:
+                        logger.info(f'Error removing directory "{checkpoint_path}": {e}.
Continuing with the next item.') + else: + logger.info(f'Directory "{checkpoint_path}" does not exist.') diff --git a/source/unet_hacked.py b/source/unet_hacked.py new file mode 100644 index 0000000000000000000000000000000000000000..a74e8b182664e81b58ae952333c2ad4382a08563 --- /dev/null +++ b/source/unet_hacked.py @@ -0,0 +1,1890 @@ +from abc import abstractmethod +import math + +import numpy as np +import torch as th +import torch.nn as nn +import torch.nn.functional as F + + +# adopted from +# https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py +# and +# https://github.com/lucidrains/denoising-diffusion-pytorch/blob/7706bdfc6f527f58d33f84b7b522e61e6e3164b3/denoising_diffusion_pytorch/denoising_diffusion_pytorch.py +# and +# https://github.com/openai/guided-diffusion/blob/0ba878e517b276c45d1195eb29f6f5f72659a05b/guided_diffusion/nn.py +# +# thanks! + + +import os +import math +import torch +import torch.nn as nn +import numpy as np +from einops import repeat +import importlib + +class Linear(nn.Linear): + def forward(self, input): + return F.linear(input, self.weight.to(input.dtype), self.bias.to(input.dtype) if self.bias is not None else None) + +def instantiate_from_config(config): + if not "target" in config: + if config == '__is_first_stage__': + return None + elif config == "__is_unconditional__": + return None + raise KeyError("Expected key `target` to instantiate.") + return get_obj_from_str(config["target"])(**config.get("params", dict())) + + +def get_obj_from_str(string, reload=False): + module, cls = string.rsplit(".", 1) + if reload: + module_imp = importlib.import_module(module) + importlib.reload(module_imp) + return getattr(importlib.import_module(module, package=None), cls) + +def make_beta_schedule(schedule, n_timestep, linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3): + if schedule == "linear": + betas = ( + torch.linspace(linear_start ** 0.5, linear_end ** 0.5, n_timestep, dtype=torch.float64) ** 2 + ) + + elif schedule == "cosine": + timesteps = ( + torch.arange(n_timestep + 1, dtype=torch.float64) / n_timestep + cosine_s + ) + alphas = timesteps / (1 + cosine_s) * np.pi / 2 + alphas = torch.cos(alphas).pow(2) + alphas = alphas / alphas[0] + betas = 1 - alphas[1:] / alphas[:-1] + betas = np.clip(betas, a_min=0, a_max=0.999) + + elif schedule == "sqrt_linear": + betas = torch.linspace(linear_start, linear_end, n_timestep, dtype=torch.float64) + elif schedule == "sqrt": + betas = torch.linspace(linear_start, linear_end, n_timestep, dtype=torch.float64) ** 0.5 + else: + raise ValueError(f"schedule '{schedule}' unknown.") + return betas.numpy() + + +def make_ddim_timesteps(ddim_discr_method, num_ddim_timesteps, num_ddpm_timesteps, verbose=True): + if ddim_discr_method == 'uniform': + c = num_ddpm_timesteps // num_ddim_timesteps + ddim_timesteps = np.asarray(list(range(0, num_ddpm_timesteps, c))) + elif ddim_discr_method == 'quad': + ddim_timesteps = ((np.linspace(0, np.sqrt(num_ddpm_timesteps * .8), num_ddim_timesteps)) ** 2).astype(int) + else: + raise NotImplementedError(f'There is no ddim discretization method called "{ddim_discr_method}"') + + # assert ddim_timesteps.shape[0] == num_ddim_timesteps + # add one to get the final alpha values right (the ones from first scale to data during sampling) + steps_out = ddim_timesteps + 1 + if verbose: + print(f'Selected timesteps for ddim sampler: {steps_out}') + return steps_out + + +def make_ddim_sampling_parameters(alphacums, ddim_timesteps, eta, verbose=True): + # select alphas 
for computing the variance schedule + alphas = alphacums[ddim_timesteps] + alphas_prev = np.asarray([alphacums[0]] + alphacums[ddim_timesteps[:-1]].tolist()) + + # according the the formula provided in https://arxiv.org/abs/2010.02502 + sigmas = eta * np.sqrt((1 - alphas_prev) / (1 - alphas) * (1 - alphas / alphas_prev)) + if verbose: + print(f'Selected alphas for ddim sampler: a_t: {alphas}; a_(t-1): {alphas_prev}') + print(f'For the chosen value of eta, which is {eta}, ' + f'this results in the following sigma_t schedule for ddim sampler {sigmas}') + return sigmas, alphas, alphas_prev + + +def betas_for_alpha_bar(num_diffusion_timesteps, alpha_bar, max_beta=0.999): + """ + Create a beta schedule that discretizes the given alpha_t_bar function, + which defines the cumulative product of (1-beta) over time from t = [0,1]. + :param num_diffusion_timesteps: the number of betas to produce. + :param alpha_bar: a lambda that takes an argument t from 0 to 1 and + produces the cumulative product of (1-beta) up to that + part of the diffusion process. + :param max_beta: the maximum beta to use; use values lower than 1 to + prevent singularities. + """ + betas = [] + for i in range(num_diffusion_timesteps): + t1 = i / num_diffusion_timesteps + t2 = (i + 1) / num_diffusion_timesteps + betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta)) + return np.array(betas) + + +def extract_into_tensor(a, t, x_shape): + b, *_ = t.shape + out = a.gather(-1, t) + return out.reshape(b, *((1,) * (len(x_shape) - 1))) + +def checkpoint(func, inputs, params, flag): + """ + Evaluate a function without caching intermediate activations, allowing for + reduced memory at the expense of extra compute in the backward pass. + :param func: the function to evaluate. + :param inputs: the argument sequence to pass to `func`. + :param params: a sequence of parameters `func` depends on but does not + explicitly take as arguments. + :param flag: if False, disable gradient checkpointing. + """ + if flag: + args = tuple(inputs) + tuple(params) + return CheckpointFunction.apply(func, len(inputs), *args) + else: + return func(*inputs) + + +class CheckpointFunction(torch.autograd.Function): + @staticmethod + def forward(ctx, run_function, length, *args): + ctx.run_function = run_function + ctx.input_tensors = list(args[:length]) + ctx.input_params = list(args[length:]) + ctx.gpu_autocast_kwargs = {"enabled": torch.is_autocast_enabled(), + "dtype": torch.get_autocast_gpu_dtype(), + "cache_enabled": torch.is_autocast_cache_enabled()} + with torch.no_grad(): + output_tensors = ctx.run_function(*ctx.input_tensors) + return output_tensors + + @staticmethod + def backward(ctx, *output_grads): + ctx.input_tensors = [x.detach().requires_grad_(True) for x in ctx.input_tensors] + with torch.enable_grad(), \ + torch.cuda.amp.autocast(**ctx.gpu_autocast_kwargs): + # Fixes a bug where the first op in run_function modifies the + # Tensor storage in place, which is not allowed for detach()'d + # Tensors. + shallow_copies = [x.view_as(x) for x in ctx.input_tensors] + output_tensors = ctx.run_function(*shallow_copies) + input_grads = torch.autograd.grad( + output_tensors, + ctx.input_tensors + ctx.input_params, + output_grads, + allow_unused=True, + ) + del ctx.input_tensors + del ctx.input_params + del output_tensors + return (None, None) + input_grads + + +def timestep_embedding(timesteps, dim, max_period=10000, repeat_only=False): + """ + Create sinusoidal timestep embeddings. 
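+    Frequencies follow max_period ** (-i / (dim // 2)) for i in
+    [0, dim // 2), i.e. the standard transformer positional encoding.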
+ :param timesteps: a 1-D Tensor of N indices, one per batch element. + These may be fractional. + :param dim: the dimension of the output. + :param max_period: controls the minimum frequency of the embeddings. + :return: an [N x dim] Tensor of positional embeddings. + """ + if not repeat_only: + half = dim // 2 + freqs = torch.exp( + -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half + ).to(device=timesteps.device) + args = timesteps[:, None].float() * freqs[None] + embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) + if dim % 2: + embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1) + else: + embedding = repeat(timesteps, 'b -> b d', d=dim) + return embedding + + +def zero_module(module): + """ + Zero out the parameters of a module and return it. + """ + for p in module.parameters(): + p.detach().zero_() + return module + + +def scale_module(module, scale): + """ + Scale the parameters of a module and return it. + """ + for p in module.parameters(): + p.detach().mul_(scale) + return module + + +def mean_flat(tensor): + """ + Take the mean over all non-batch dimensions. + """ + return tensor.mean(dim=list(range(1, len(tensor.shape)))) + + +def normalization(channels): + """ + Make a standard normalization layer. + :param channels: number of input channels. + :return: an nn.Module for normalization. + """ + return GroupNorm32(32, channels) + + +# PyTorch 1.7 has SiLU, but we support PyTorch 1.5. +SiLU = nn.SiLU(inplace=True) +# class SiLU(nn.Module): +# def forward(self, x): +# return x * torch.sigmoid(x) + + +class GroupNorm32(nn.GroupNorm): + def forward(self, x): + return super().forward(x) + +def conv_nd(dims, *args, **kwargs): + """ + Create a 1D, 2D, or 3D convolution module. + """ + if dims == 1: + return nn.Conv1d(*args, **kwargs) + elif dims == 2: + return nn.Conv2d(*args, **kwargs) + elif dims == 3: + return nn.Conv3d(*args, **kwargs) + raise ValueError(f"unsupported dimensions: {dims}") + + +def linear(*args, **kwargs): + """ + Create a linear module. + """ + return Linear(*args, **kwargs) + + +def avg_pool_nd(dims, *args, **kwargs): + """ + Create a 1D, 2D, or 3D average pooling module. 
+ """ + if dims == 1: + return nn.AvgPool1d(*args, **kwargs) + elif dims == 2: + return nn.AvgPool2d(*args, **kwargs) + elif dims == 3: + return nn.AvgPool3d(*args, **kwargs) + raise ValueError(f"unsupported dimensions: {dims}") + + +class HybridConditioner(nn.Module): + + def __init__(self, c_concat_config, c_crossattn_config): + super().__init__() + self.concat_conditioner = instantiate_from_config(c_concat_config) + self.crossattn_conditioner = instantiate_from_config(c_crossattn_config) + + def forward(self, c_concat, c_crossattn): + c_concat = self.concat_conditioner(c_concat) + c_crossattn = self.crossattn_conditioner(c_crossattn) + return {'c_concat': [c_concat], 'c_crossattn': [c_crossattn]} + + +def noise_like(shape, device, repeat=False): + repeat_noise = lambda: torch.randn((1, *shape[1:]), device=device).repeat(shape[0], *((1,) * (len(shape) - 1))) + noise = lambda: torch.randn(shape, device=device) + return repeat_noise() if repeat else noise() + +from inspect import isfunction +import math +import torch +import torch.nn.functional as F +from torch import nn, einsum +from einops import rearrange, repeat +from typing import Optional, Any + + +try: + import xformers + import xformers.ops + XFORMERS_IS_AVAILBLE = True +except: + XFORMERS_IS_AVAILBLE = False + +# CrossAttn precision handling +import os +_ATTN_PRECISION = os.environ.get("ATTN_PRECISION", "fp16") + +def exists(val): + return val is not None + + +def uniq(arr): + return{el: True for el in arr}.keys() + + +def default(val, d): + if exists(val): + return val + return d() if isfunction(d) else d + + +def max_neg_value(t): + return -torch.finfo(t.dtype).max + + +def init_(tensor): + dim = tensor.shape[-1] + std = 1 / math.sqrt(dim) + tensor.uniform_(-std, std) + return tensor + + +# feedforward +class GEGLU(nn.Module): + def __init__(self, dim_in, dim_out): + super().__init__() + self.proj = Linear(dim_in, dim_out * 2) + + def forward(self, x): + x, gate = self.proj(x).chunk(2, dim=-1) + return x * F.gelu(gate) + + +class FeedForward(nn.Module): + def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.): + super().__init__() + inner_dim = int(dim * mult) + dim_out = default(dim_out, dim) + project_in = nn.Sequential( + Linear(dim, inner_dim), + nn.GELU() + ) if not glu else GEGLU(dim, inner_dim) + + self.net = nn.Sequential( + project_in, + nn.Dropout(dropout), + Linear(inner_dim, dim_out) + ) + + def forward(self, x): + return self.net(x) + + +def zero_module(module): + """ + Zero out the parameters of a module and return it. 
+ """ + for p in module.parameters(): + p.detach().zero_() + return module + + +def Normalize(in_channels): + return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True) + + +class SpatialSelfAttention(nn.Module): + def __init__(self, in_channels): + super().__init__() + self.in_channels = in_channels + + self.norm = Normalize(in_channels) + self.q = torch.nn.Conv2d(in_channels, + in_channels, + kernel_size=1, + stride=1, + padding=0) + self.k = torch.nn.Conv2d(in_channels, + in_channels, + kernel_size=1, + stride=1, + padding=0) + self.v = torch.nn.Conv2d(in_channels, + in_channels, + kernel_size=1, + stride=1, + padding=0) + self.proj_out = torch.nn.Conv2d(in_channels, + in_channels, + kernel_size=1, + stride=1, + padding=0) + + def forward(self, x): + h_ = x + h_ = self.norm(h_) + q = self.q(h_) + k = self.k(h_) + v = self.v(h_) + + # compute attention + b,c,h,w = q.shape + q = rearrange(q, 'b c h w -> b (h w) c') + k = rearrange(k, 'b c h w -> b c (h w)') + w_ = torch.einsum('bij,bjk->bik', q, k) + + w_ = w_ * (int(c)**(-0.5)) + w_ = torch.nn.functional.softmax(w_, dim=2) + + # attend to values + v = rearrange(v, 'b c h w -> b c (h w)') + w_ = rearrange(w_, 'b i j -> b j i') + h_ = torch.einsum('bij,bjk->bik', v, w_) + h_ = rearrange(h_, 'b c (h w) -> b c h w', h=h) + h_ = self.proj_out(h_) + + return x+h_ + + +class CrossAttention(nn.Module): + def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.): + super().__init__() + inner_dim = dim_head * heads + context_dim = default(context_dim, query_dim) + + self.scale = dim_head ** -0.5 + self.heads = heads + + self.to_q = Linear(query_dim, inner_dim, bias=False) + self.to_k = Linear(context_dim, inner_dim, bias=False) + self.to_v = Linear(context_dim, inner_dim, bias=False) + + self.to_out = nn.Sequential( + Linear(inner_dim, query_dim), + nn.Dropout(dropout) + ) + + def forward(self, x, context=None, mask=None): + h = self.heads + + q = self.to_q(x) + context = default(context, x) + k = self.to_k(context) + v = self.to_v(context) + + q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v)) + + # force cast to fp32 to avoid overflowing + if _ATTN_PRECISION =="fp32": + with torch.autocast(enabled=False, device_type = 'cuda'): + q, k = q.float(), k.float() + sim = einsum('b i d, b j d -> b i j', q, k) * self.scale + else: + sim = einsum('b i d, b j d -> b i j', q, k) * self.scale + + del q, k + + if exists(mask): + mask = rearrange(mask, 'b ... -> b (...)') + max_neg_value = -torch.finfo(sim.dtype).max + mask = repeat(mask, 'b j -> (b h) () j', h=h) + sim.masked_fill_(~mask, max_neg_value) + + # attention, what we cannot get enough of + sim = sim.softmax(dim=-1) + + out = einsum('b i j, b j d -> b i d', sim, v) + out = rearrange(out, '(b h) n d -> b n (h d)', h=h) + return self.to_out(out) + + +class MemoryEfficientCrossAttention(nn.Module): + # https://github.com/MatthieuTPHR/diffusers/blob/d80b531ff8060ec1ea982b65a1b8df70f73aa67c/src/diffusers/models/attention.py#L223 + def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.0): + super().__init__() + print(f"Setting up {self.__class__.__name__}. 
Query dim is {query_dim}, context_dim is {context_dim} and using " + f"{heads} heads.") + inner_dim = dim_head * heads + context_dim = default(context_dim, query_dim) + + self.heads = heads + self.dim_head = dim_head + + self.to_q = Linear(query_dim, inner_dim, bias=False) + self.to_k = Linear(context_dim, inner_dim, bias=False) + self.to_v = Linear(context_dim, inner_dim, bias=False) + + self.to_out = nn.Sequential(Linear(inner_dim, query_dim), nn.Dropout(dropout)) + self.attention_op: Optional[Any] = None + + def forward(self, x, context=None, mask=None): + q = self.to_q(x) + context = default(context, x) + k = self.to_k(context) + v = self.to_v(context) + + b, _, _ = q.shape + q, k, v = map( + lambda t: t.unsqueeze(3) + .reshape(b, t.shape[1], self.heads, self.dim_head) + .permute(0, 2, 1, 3) + .reshape(b * self.heads, t.shape[1], self.dim_head) + .contiguous(), + (q, k, v), + ) + + # actually compute the attention, what we cannot get enough of + out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=None, op=self.attention_op) + + if exists(mask): + raise NotImplementedError + out = ( + out.unsqueeze(0) + .reshape(b, self.heads, out.shape[1], self.dim_head) + .permute(0, 2, 1, 3) + .reshape(b, out.shape[1], self.heads * self.dim_head) + ) + return self.to_out(out) + + +class BasicTransformerBlock(nn.Module): + ATTENTION_MODES = { + "softmax": CrossAttention, # vanilla attention + "softmax-xformers": MemoryEfficientCrossAttention + } + def __init__(self, dim, n_heads, d_head, dropout=0., context_dim=None, gated_ff=True, checkpoint=True, + disable_self_attn=False): + super().__init__() + attn_mode = "softmax-xformers" if XFORMERS_IS_AVAILBLE else "softmax" + assert attn_mode in self.ATTENTION_MODES + attn_cls = self.ATTENTION_MODES[attn_mode] + self.disable_self_attn = disable_self_attn + self.attn1 = attn_cls(query_dim=dim, heads=n_heads, dim_head=d_head, dropout=dropout, + context_dim=context_dim if self.disable_self_attn else None) # is a self-attention if not self.disable_self_attn + self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff) + self.attn2 = attn_cls(query_dim=dim, context_dim=context_dim, + heads=n_heads, dim_head=d_head, dropout=dropout) # is self-attn if context is none + self.norm1 = nn.LayerNorm(dim) + self.norm2 = nn.LayerNorm(dim) + self.norm3 = nn.LayerNorm(dim) + self.checkpoint = checkpoint + + def forward(self, x, context=None): + return checkpoint(self._forward, (x, context), self.parameters(), self.checkpoint) + + def _forward(self, x, context=None): + x = self.attn1(self.norm1(x), context=context if self.disable_self_attn else None) + x + x = self.attn2(self.norm2(x), context=context) + x + x = self.ff(self.norm3(x)) + x + return x + + +class SpatialTransformer(nn.Module): + """ + Transformer block for image-like data. + First, project the input (aka embedding) + and reshape to b, t, d. + Then apply standard transformer action. 
+ Finally, reshape to image + NEW: use_linear for more efficiency instead of the 1x1 convs + """ + def __init__(self, in_channels, n_heads, d_head, + depth=1, dropout=0., context_dim=None, + disable_self_attn=False, use_linear=False, + use_checkpoint=True): + super().__init__() + if exists(context_dim) and not isinstance(context_dim, list): + context_dim = [context_dim] + self.in_channels = in_channels + inner_dim = n_heads * d_head + self.norm = Normalize(in_channels) + if not use_linear: + self.proj_in = nn.Conv2d(in_channels, + inner_dim, + kernel_size=1, + stride=1, + padding=0) + else: + self.proj_in = Linear(in_channels, inner_dim) + + self.transformer_blocks = nn.ModuleList( + [BasicTransformerBlock(inner_dim, n_heads, d_head, dropout=dropout, context_dim=context_dim[d], + disable_self_attn=disable_self_attn, checkpoint=use_checkpoint) + for d in range(depth)] + ) + if not use_linear: + self.proj_out = zero_module(nn.Conv2d(inner_dim, + in_channels, + kernel_size=1, + stride=1, + padding=0)) + else: + self.proj_out = zero_module(Linear(in_channels, inner_dim)) + self.use_linear = use_linear + + def forward(self, x, context=None): + # note: if no context is given, cross-attention defaults to self-attention + if not isinstance(context, list): + context = [context] + b, c, h, w = x.shape + x_in = x + x = self.norm(x) + if not self.use_linear: + x = self.proj_in(x) + x = rearrange(x, 'b c h w -> b (h w) c').contiguous() + if self.use_linear: + x = self.proj_in(x) + for i, block in enumerate(self.transformer_blocks): + x = block(x, context=context[i]) + if self.use_linear: + x = self.proj_out(x) + x = rearrange(x, 'b (h w) c -> b c h w', h=h, w=w).contiguous() + if not self.use_linear: + x = self.proj_out(x) + return x + x_in + + +class BasicTransformerBlock3D(BasicTransformerBlock): + + def forward(self, x, context=None, num_frames=1): + return checkpoint(self._forward, (x, context, num_frames), self.parameters(), False) + + def _forward(self, x, context=None, num_frames=1): + x = rearrange(x, "(b f) l c -> b (f l) c", f=num_frames).contiguous() + x = self.attn1(self.norm1(x), context=context if self.disable_self_attn else None) + x + x = rearrange(x, "b (f l) c -> (b f) l c", f=num_frames).contiguous() + x = self.attn2(self.norm2(x), context=context) + x + x = self.ff(self.norm3(x)) + x + return x + + +class SpatialTransformer3D(nn.Module): + ''' 3D self-attention ''' + def __init__(self, in_channels, n_heads, d_head, + depth=1, dropout=0., context_dim=None, + disable_self_attn=False, use_linear=False, + use_checkpoint=True): + super().__init__() + if exists(context_dim) and not isinstance(context_dim, list): + context_dim = [context_dim] + self.in_channels = in_channels + inner_dim = n_heads * d_head + self.norm = Normalize(in_channels) + if not use_linear: + self.proj_in = nn.Conv2d(in_channels, + inner_dim, + kernel_size=1, + stride=1, + padding=0) + else: + self.proj_in = Linear(in_channels, inner_dim) + + self.transformer_blocks = nn.ModuleList( + [BasicTransformerBlock3D(inner_dim, n_heads, d_head, dropout=dropout, context_dim=context_dim[d], + disable_self_attn=disable_self_attn, checkpoint=use_checkpoint) + for d in range(depth)] + ) + if not use_linear: + self.proj_out = zero_module(nn.Conv2d(inner_dim, + in_channels, + kernel_size=1, + stride=1, + padding=0)) + else: + self.proj_out = zero_module(Linear(in_channels, inner_dim)) + self.use_linear = use_linear + + def forward(self, x, context=None, num_frames=1): + # note: if no context is given, cross-attention defaults to 
self-attention + if not isinstance(context, list): + context = [context] + b, c, h, w = x.shape + x_in = x + x = self.norm(x) + if not self.use_linear: + x = self.proj_in(x) + x = rearrange(x, 'b c h w -> b (h w) c').contiguous() + if self.use_linear: + x = self.proj_in(x) + for i, block in enumerate(self.transformer_blocks): + x = block(x, context=context[i], num_frames=num_frames) + if self.use_linear: + x = self.proj_out(x) + x = rearrange(x, 'b (h w) c -> b c h w', h=h, w=w).contiguous() + if not self.use_linear: + x = self.proj_out(x) + return x + x_in + +# dummy replace +def convert_module_to_f16(x): + pass + +def convert_module_to_f32(x): + pass + + +## go +class AttentionPool2d(nn.Module): + """ + Adapted from CLIP: https://github.com/openai/CLIP/blob/main/clip/model.py + """ + + def __init__( + self, + spacial_dim: int, + embed_dim: int, + num_heads_channels: int, + output_dim: int = None, + ): + super().__init__() + self.positional_embedding = nn.Parameter(th.randn(embed_dim, spacial_dim ** 2 + 1) / embed_dim ** 0.5) + self.qkv_proj = conv_nd(1, embed_dim, 3 * embed_dim, 1) + self.c_proj = conv_nd(1, embed_dim, output_dim or embed_dim, 1) + self.num_heads = embed_dim // num_heads_channels + self.attention = QKVAttention(self.num_heads) + + def forward(self, x): + b, c, *_spatial = x.shape + x = x.reshape(b, c, -1) # NC(HW) + x = th.cat([x.mean(dim=-1, keepdim=True), x], dim=-1) # NC(HW+1) + x = x + self.positional_embedding[None, :, :].to(x.dtype) # NC(HW+1) + x = self.qkv_proj(x) + x = self.attention(x) + x = self.c_proj(x) + return x[:, :, 0] + + +class TimestepBlock(nn.Module): + """ + Any module where forward() takes timestep embeddings as a second argument. + """ + + @abstractmethod + def forward(self, x, emb): + """ + Apply the module to `x` given `emb` timestep embeddings. + """ + + +class TimestepEmbedSequential(nn.Sequential, TimestepBlock): + """ + A sequential module that passes timestep embeddings to the children that + support it as an extra input. + """ + + def forward(self, x, emb, context=None, num_frames=1): + for layer in self: + if isinstance(layer, TimestepBlock): + x = layer(x, emb) + elif isinstance(layer, SpatialTransformer3D): + x = layer(x, context, num_frames=num_frames) + elif isinstance(layer, SpatialTransformer): + x = layer(x, context) + else: + x = layer(x) + return x + + +class Upsample(nn.Module): + """ + An upsampling layer with an optional convolution. + :param channels: channels in the inputs and outputs. + :param use_conv: a bool determining if a convolution is applied. + :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then + upsampling occurs in the inner-two dimensions. 
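+
+    A minimal sketch (sizes are illustrative):
+
+        up = Upsample(channels=64, use_conv=True)
+        x = th.randn(2, 64, 16, 16)
+        y = up(x)  # nearest-neighbour 2x: (2, 64, 32, 32)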
+ """ + + def __init__(self, channels, use_conv, dims=2, out_channels=None, padding=1): + super().__init__() + self.channels = channels + self.out_channels = out_channels or channels + self.use_conv = use_conv + self.dims = dims + if use_conv: + self.conv = conv_nd(dims, self.channels, self.out_channels, 3, padding=padding) + + def forward(self, x): + assert x.shape[1] == self.channels + if self.dims == 3: + x = F.interpolate( + x, (x.shape[2], x.shape[3] * 2, x.shape[4] * 2), mode="nearest" + ) + else: + x = F.interpolate(x, scale_factor=2, mode="nearest") + if self.use_conv: + x = self.conv(x) + return x + +class TransposedUpsample(nn.Module): + 'Learned 2x upsampling without padding' + def __init__(self, channels, out_channels=None, ks=5): + super().__init__() + self.channels = channels + self.out_channels = out_channels or channels + + self.up = nn.ConvTranspose2d(self.channels,self.out_channels,kernel_size=ks,stride=2) + + def forward(self,x): + return self.up(x) + + +class Downsample(nn.Module): + """ + A downsampling layer with an optional convolution. + :param channels: channels in the inputs and outputs. + :param use_conv: a bool determining if a convolution is applied. + :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then + downsampling occurs in the inner-two dimensions. + """ + + def __init__(self, channels, use_conv, dims=2, out_channels=None,padding=1): + super().__init__() + self.channels = channels + self.out_channels = out_channels or channels + self.use_conv = use_conv + self.dims = dims + stride = 2 if dims != 3 else (1, 2, 2) + if use_conv: + self.op = conv_nd( + dims, self.channels, self.out_channels, 3, stride=stride, padding=padding + ) + else: + assert self.channels == self.out_channels + self.op = avg_pool_nd(dims, kernel_size=stride, stride=stride) + + def forward(self, x): + assert x.shape[1] == self.channels + return self.op(x) + + +class ResBlock(TimestepBlock): + """ + A residual block that can optionally change the number of channels. + :param channels: the number of input channels. + :param emb_channels: the number of timestep embedding channels. + :param dropout: the rate of dropout. + :param out_channels: if specified, the number of out channels. + :param use_conv: if True and out_channels is specified, use a spatial + convolution instead of a smaller 1x1 convolution to change the + channels in the skip connection. + :param dims: determines if the signal is 1D, 2D, or 3D. + :param use_checkpoint: if True, use gradient checkpointing on this module. + :param up: if True, use this block for upsampling. + :param down: if True, use this block for downsampling. 
+ """ + + def __init__( + self, + channels, + emb_channels, + dropout, + out_channels=None, + use_conv=False, + use_scale_shift_norm=False, + dims=2, + use_checkpoint=False, + up=False, + down=False, + ): + super().__init__() + self.channels = channels + self.emb_channels = emb_channels + self.dropout = dropout + self.out_channels = out_channels or channels + self.use_conv = use_conv + self.use_checkpoint = use_checkpoint + self.use_scale_shift_norm = use_scale_shift_norm + + self.in_layers = nn.Sequential( + normalization(channels), + nn.SiLU(), + conv_nd(dims, channels, self.out_channels, 3, padding=1), + ) + + self.updown = up or down + + if up: + self.h_upd = Upsample(channels, False, dims) + self.x_upd = Upsample(channels, False, dims) + elif down: + self.h_upd = Downsample(channels, False, dims) + self.x_upd = Downsample(channels, False, dims) + else: + self.h_upd = self.x_upd = nn.Identity() + + self.emb_layers = nn.Sequential( + nn.SiLU(), + linear( + emb_channels, + 2 * self.out_channels if use_scale_shift_norm else self.out_channels, + ), + ) + self.out_layers = nn.Sequential( + normalization(self.out_channels), + nn.SiLU(), + nn.Dropout(p=dropout), + zero_module( + conv_nd(dims, self.out_channels, self.out_channels, 3, padding=1) + ), + ) + + if self.out_channels == channels: + self.skip_connection = nn.Identity() + elif use_conv: + self.skip_connection = conv_nd( + dims, channels, self.out_channels, 3, padding=1 + ) + else: + self.skip_connection = conv_nd(dims, channels, self.out_channels, 1) + + def forward(self, x, emb): + """ + Apply the block to a Tensor, conditioned on a timestep embedding. + :param x: an [N x C x ...] Tensor of features. + :param emb: an [N x emb_channels] Tensor of timestep embeddings. + :return: an [N x C x ...] Tensor of outputs. + """ + return checkpoint( + self._forward, (x, emb), self.parameters(), self.use_checkpoint + ) + + + def _forward(self, x, emb): + if self.updown: + in_rest, in_conv = self.in_layers[:-1], self.in_layers[-1] + h = in_rest(x) + h = self.h_upd(h) + x = self.x_upd(x) + h = in_conv(h) + else: + h = self.in_layers(x) + emb_out = self.emb_layers(emb).type(h.dtype) + while len(emb_out.shape) < len(h.shape): + emb_out = emb_out[..., None] + if self.use_scale_shift_norm: + out_norm, out_rest = self.out_layers[0], self.out_layers[1:] + scale, shift = th.chunk(emb_out, 2, dim=1) + h = out_norm(h) * (1 + scale) + shift + h = out_rest(h) + else: + h = h + emb_out + h = self.out_layers(h) + return self.skip_connection(x) + h + + +class AttentionBlock(nn.Module): + """ + An attention block that allows spatial positions to attend to each other. + Originally ported from here, but adapted to the N-d case. + https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/models/unet.py#L66. 
+ """ + + def __init__( + self, + channels, + num_heads=1, + num_head_channels=-1, + use_checkpoint=False, + use_new_attention_order=False, + ): + super().__init__() + self.channels = channels + if num_head_channels == -1: + self.num_heads = num_heads + else: + assert ( + channels % num_head_channels == 0 + ), f"q,k,v channels {channels} is not divisible by num_head_channels {num_head_channels}" + self.num_heads = channels // num_head_channels + self.use_checkpoint = use_checkpoint + self.norm = normalization(channels) + self.qkv = conv_nd(1, channels, channels * 3, 1) + if use_new_attention_order: + # split qkv before split heads + self.attention = QKVAttention(self.num_heads) + else: + # split heads before split qkv + self.attention = QKVAttentionLegacy(self.num_heads) + + self.proj_out = zero_module(conv_nd(1, channels, channels, 1)) + + def forward(self, x): + return checkpoint(self._forward, (x,), self.parameters(), True) + + def _forward(self, x): + b, c, *spatial = x.shape + x = x.reshape(b, c, -1) + qkv = self.qkv(self.norm(x)) + h = self.attention(qkv) + h = self.proj_out(h) + return (x + h).reshape(b, c, *spatial) + + +def count_flops_attn(model, _x, y): + """ + A counter for the `thop` package to count the operations in an + attention operation. + Meant to be used like: + macs, params = thop.profile( + model, + inputs=(inputs, timestamps), + custom_ops={QKVAttention: QKVAttention.count_flops}, + ) + """ + b, c, *spatial = y[0].shape + num_spatial = int(np.prod(spatial)) + # We perform two matmuls with the same number of ops. + # The first computes the weight matrix, the second computes + # the combination of the value vectors. + matmul_ops = 2 * b * (num_spatial ** 2) * c + model.total_ops += th.DoubleTensor([matmul_ops]) + + +class QKVAttentionLegacy(nn.Module): + """ + A module which performs QKV attention. Matches legacy QKVAttention + input/ouput heads shaping + """ + + def __init__(self, n_heads): + super().__init__() + self.n_heads = n_heads + + def forward(self, qkv): + """ + Apply QKV attention. + :param qkv: an [N x (H * 3 * C) x T] tensor of Qs, Ks, and Vs. + :return: an [N x (H * C) x T] tensor after attention. + """ + bs, width, length = qkv.shape + assert width % (3 * self.n_heads) == 0 + ch = width // (3 * self.n_heads) + q, k, v = qkv.reshape(bs * self.n_heads, ch * 3, length).split(ch, dim=1) + scale = 1 / math.sqrt(math.sqrt(ch)) + weight = th.einsum( + "bct,bcs->bts", q * scale, k * scale + ) # More stable with f16 than dividing afterwards + weight = th.softmax(weight.float(), dim=-1).type(weight.dtype) + a = th.einsum("bts,bcs->bct", weight, v) + return a.reshape(bs, -1, length) + + @staticmethod + def count_flops(model, _x, y): + return count_flops_attn(model, _x, y) + + +class QKVAttention(nn.Module): + """ + A module which performs QKV attention and splits in a different order. + """ + + def __init__(self, n_heads): + super().__init__() + self.n_heads = n_heads + + def forward(self, qkv): + """ + Apply QKV attention. + :param qkv: an [N x (3 * H * C) x T] tensor of Qs, Ks, and Vs. + :return: an [N x (H * C) x T] tensor after attention. 
+ """ + bs, width, length = qkv.shape + assert width % (3 * self.n_heads) == 0 + ch = width // (3 * self.n_heads) + q, k, v = qkv.chunk(3, dim=1) + scale = 1 / math.sqrt(math.sqrt(ch)) + weight = th.einsum( + "bct,bcs->bts", + (q * scale).view(bs * self.n_heads, ch, length), + (k * scale).view(bs * self.n_heads, ch, length), + ) # More stable with f16 than dividing afterwards + weight = th.softmax(weight.float(), dim=-1).type(weight.dtype) + a = th.einsum("bts,bcs->bct", weight, v.reshape(bs * self.n_heads, ch, length)) + return a.reshape(bs, -1, length) + + @staticmethod + def count_flops(model, _x, y): + return count_flops_attn(model, _x, y) + + +class Timestep(nn.Module): + def __init__(self, dim): + super().__init__() + self.dim = dim + + def forward(self, t): + return timestep_embedding(t, self.dim) + + +class UNetModel(nn.Module): + """ + The full UNet model with attention and timestep embedding. + :param in_channels: channels in the input Tensor. + :param model_channels: base channel count for the model. + :param out_channels: channels in the output Tensor. + :param num_res_blocks: number of residual blocks per downsample. + :param attention_resolutions: a collection of downsample rates at which + attention will take place. May be a set, list, or tuple. + For example, if this contains 4, then at 4x downsampling, attention + will be used. + :param dropout: the dropout probability. + :param channel_mult: channel multiplier for each level of the UNet. + :param conv_resample: if True, use learned convolutions for upsampling and + downsampling. + :param dims: determines if the signal is 1D, 2D, or 3D. + :param num_classes: if specified (as an int), then this model will be + class-conditional with `num_classes` classes. + :param use_checkpoint: use gradient checkpointing to reduce memory usage. + :param num_heads: the number of attention heads in each attention layer. + :param num_heads_channels: if specified, ignore num_heads and instead use + a fixed channel width per attention head. + :param num_heads_upsample: works with num_heads to set a different number + of heads for upsampling. Deprecated. + :param use_scale_shift_norm: use a FiLM-like conditioning mechanism. + :param resblock_updown: use residual blocks for up/downsampling. + :param use_new_attention_order: use a different attention pattern for potentially + increased efficiency. + """ + + def __init__( + self, + image_size, + in_channels, + model_channels, + out_channels, + num_res_blocks, + attention_resolutions, + dropout=0, + channel_mult=(1, 2, 4, 8), + conv_resample=True, + dims=2, + num_classes=None, + use_checkpoint=False, + use_fp16=False, + use_bf16=False, + num_heads=-1, + num_head_channels=-1, + num_heads_upsample=-1, + use_scale_shift_norm=False, + resblock_updown=False, + use_new_attention_order=False, + use_spatial_transformer=False, # custom transformer support + transformer_depth=1, # custom transformer support + context_dim=None, # custom transformer support + n_embed=None, # custom support for prediction of discrete ids into codebook of first stage vq model + legacy=True, + disable_self_attentions=None, + num_attention_blocks=None, + disable_middle_self_attn=False, + use_linear_in_transformer=False, + adm_in_channels=None, + ): + super().__init__() + if use_spatial_transformer: + assert context_dim is not None, 'Fool!! You forgot to include the dimension of your cross-attention conditioning...' + + if context_dim is not None: + assert use_spatial_transformer, 'Fool!! 
You forgot to use the spatial transformer for your cross-attention conditioning...' + from omegaconf.listconfig import ListConfig + if type(context_dim) == ListConfig: + context_dim = list(context_dim) + + if num_heads_upsample == -1: + num_heads_upsample = num_heads + + if num_heads == -1: + assert num_head_channels != -1, 'Either num_heads or num_head_channels has to be set' + + if num_head_channels == -1: + assert num_heads != -1, 'Either num_heads or num_head_channels has to be set' + + self.image_size = image_size + self.in_channels = in_channels + self.model_channels = model_channels + self.out_channels = out_channels + if isinstance(num_res_blocks, int): + self.num_res_blocks = len(channel_mult) * [num_res_blocks] + else: + if len(num_res_blocks) != len(channel_mult): + raise ValueError("provide num_res_blocks either as an int (globally constant) or " + "as a list/tuple (per-level) with the same length as channel_mult") + self.num_res_blocks = num_res_blocks + if disable_self_attentions is not None: + # should be a list of booleans, indicating whether to disable self-attention in TransformerBlocks or not + assert len(disable_self_attentions) == len(channel_mult) + if num_attention_blocks is not None: + assert len(num_attention_blocks) == len(self.num_res_blocks) + assert all(map(lambda i: self.num_res_blocks[i] >= num_attention_blocks[i], range(len(num_attention_blocks)))) + print(f"Constructor of UNetModel received num_attention_blocks={num_attention_blocks}. " + f"This option has LESS priority than attention_resolutions {attention_resolutions}, " + f"i.e., in cases where num_attention_blocks[i] > 0 but 2**i not in attention_resolutions, " + f"attention will still not be set.") + + self.attention_resolutions = attention_resolutions + self.dropout = dropout + self.channel_mult = channel_mult + self.conv_resample = conv_resample + self.num_classes = num_classes + self.use_checkpoint = use_checkpoint + self.dtype = th.float16 if use_fp16 else th.float32 + self.dtype = th.bfloat16 if use_bf16 else self.dtype + self.num_heads = num_heads + self.num_head_channels = num_head_channels + self.num_heads_upsample = num_heads_upsample + self.predict_codebook_ids = n_embed is not None + + time_embed_dim = model_channels * 4 + self.time_embed = nn.Sequential( + linear(model_channels, time_embed_dim), + nn.SiLU(), + linear(time_embed_dim, time_embed_dim), + ) + + if self.num_classes is not None: + if isinstance(self.num_classes, int): + self.label_emb = nn.Embedding(num_classes, time_embed_dim) + elif self.num_classes == "continuous": + print("setting up linear c_adm embedding layer") + self.label_emb = Linear(1, time_embed_dim) + elif self.num_classes == "sequential": + assert adm_in_channels is not None + self.label_emb = nn.Sequential( + nn.Sequential( + linear(adm_in_channels, time_embed_dim), + nn.SiLU(), + linear(time_embed_dim, time_embed_dim), + ) + ) + else: + raise ValueError() + + self.input_blocks = nn.ModuleList( + [ + TimestepEmbedSequential( + conv_nd(dims, in_channels, model_channels, 3, padding=1) + ) + ] + ) + self._feature_size = model_channels + input_block_chans = [model_channels] + ch = model_channels + ds = 1 + for level, mult in enumerate(channel_mult): + for nr in range(self.num_res_blocks[level]): + layers = [ + ResBlock( + ch, + time_embed_dim, + dropout, + out_channels=mult * model_channels, + dims=dims, + use_checkpoint=use_checkpoint, + use_scale_shift_norm=use_scale_shift_norm, + ) + ] + ch = mult * model_channels + if ds in attention_resolutions: + if 
num_head_channels == -1: + dim_head = ch // num_heads + else: + num_heads = ch // num_head_channels + dim_head = num_head_channels + if legacy: + #num_heads = 1 + dim_head = ch // num_heads if use_spatial_transformer else num_head_channels + if exists(disable_self_attentions): + disabled_sa = disable_self_attentions[level] + else: + disabled_sa = False + + if not exists(num_attention_blocks) or nr < num_attention_blocks[level]: + layers.append( + AttentionBlock( + ch, + use_checkpoint=use_checkpoint, + num_heads=num_heads, + num_head_channels=dim_head, + use_new_attention_order=use_new_attention_order, + ) if not use_spatial_transformer else SpatialTransformer( + ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim, + disable_self_attn=disabled_sa, use_linear=use_linear_in_transformer, + use_checkpoint=use_checkpoint + ) + ) + self.input_blocks.append(TimestepEmbedSequential(*layers)) + self._feature_size += ch + input_block_chans.append(ch) + if level != len(channel_mult) - 1: + out_ch = ch + self.input_blocks.append( + TimestepEmbedSequential( + ResBlock( + ch, + time_embed_dim, + dropout, + out_channels=out_ch, + dims=dims, + use_checkpoint=use_checkpoint, + use_scale_shift_norm=use_scale_shift_norm, + down=True, + ) + if resblock_updown + else Downsample( + ch, conv_resample, dims=dims, out_channels=out_ch + ) + ) + ) + ch = out_ch + input_block_chans.append(ch) + ds *= 2 + self._feature_size += ch + + if num_head_channels == -1: + dim_head = ch // num_heads + else: + num_heads = ch // num_head_channels + dim_head = num_head_channels + if legacy: + #num_heads = 1 + dim_head = ch // num_heads if use_spatial_transformer else num_head_channels + self.middle_block = TimestepEmbedSequential( + ResBlock( + ch, + time_embed_dim, + dropout, + dims=dims, + use_checkpoint=use_checkpoint, + use_scale_shift_norm=use_scale_shift_norm, + ), + AttentionBlock( + ch, + use_checkpoint=use_checkpoint, + num_heads=num_heads, + num_head_channels=dim_head, + use_new_attention_order=use_new_attention_order, + ) if not use_spatial_transformer else SpatialTransformer( # always uses a self-attn + ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim, + disable_self_attn=disable_middle_self_attn, use_linear=use_linear_in_transformer, + use_checkpoint=use_checkpoint + ), + ResBlock( + ch, + time_embed_dim, + dropout, + dims=dims, + use_checkpoint=use_checkpoint, + use_scale_shift_norm=use_scale_shift_norm, + ), + ) + self._feature_size += ch + + self.output_blocks = nn.ModuleList([]) + for level, mult in list(enumerate(channel_mult))[::-1]: + for i in range(self.num_res_blocks[level] + 1): + ich = input_block_chans.pop() + layers = [ + ResBlock( + ch + ich, + time_embed_dim, + dropout, + out_channels=model_channels * mult, + dims=dims, + use_checkpoint=use_checkpoint, + use_scale_shift_norm=use_scale_shift_norm, + ) + ] + ch = model_channels * mult + if ds in attention_resolutions: + if num_head_channels == -1: + dim_head = ch // num_heads + else: + num_heads = ch // num_head_channels + dim_head = num_head_channels + if legacy: + #num_heads = 1 + dim_head = ch // num_heads if use_spatial_transformer else num_head_channels + if exists(disable_self_attentions): + disabled_sa = disable_self_attentions[level] + else: + disabled_sa = False + + if not exists(num_attention_blocks) or i < num_attention_blocks[level]: + layers.append( + AttentionBlock( + ch, + use_checkpoint=use_checkpoint, + num_heads=num_heads_upsample, + num_head_channels=dim_head, + 
use_new_attention_order=use_new_attention_order, + ) if not use_spatial_transformer else SpatialTransformer( + ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim, + disable_self_attn=disabled_sa, use_linear=use_linear_in_transformer, + use_checkpoint=use_checkpoint + ) + ) + if level and i == self.num_res_blocks[level]: + out_ch = ch + layers.append( + ResBlock( + ch, + time_embed_dim, + dropout, + out_channels=out_ch, + dims=dims, + use_checkpoint=use_checkpoint, + use_scale_shift_norm=use_scale_shift_norm, + up=True, + ) + if resblock_updown + else Upsample(ch, conv_resample, dims=dims, out_channels=out_ch) + ) + ds //= 2 + self.output_blocks.append(TimestepEmbedSequential(*layers)) + self._feature_size += ch + + self.out = nn.Sequential( + normalization(ch), + nn.SiLU(), + zero_module(conv_nd(dims, model_channels, out_channels, 3, padding=1)), + ) + if self.predict_codebook_ids: + self.id_predictor = nn.Sequential( + normalization(ch), + conv_nd(dims, model_channels, n_embed, 1), + #nn.LogSoftmax(dim=1) # change to cross_entropy and produce non-normalized logits + ) + + def convert_to_fp16(self): + """ + Convert the torso of the model to float16. + """ + self.input_blocks.apply(convert_module_to_f16) + self.middle_block.apply(convert_module_to_f16) + self.output_blocks.apply(convert_module_to_f16) + + def convert_to_fp32(self): + """ + Convert the torso of the model to float32. + """ + self.input_blocks.apply(convert_module_to_f32) + self.middle_block.apply(convert_module_to_f32) + self.output_blocks.apply(convert_module_to_f32) + + def forward(self, x, timesteps=None, context=None, y=None,**kwargs): + """ + Apply the model to an input batch. + :param x: an [N x C x ...] Tensor of inputs. + :param timesteps: a 1-D batch of timesteps. + :param context: conditioning plugged in via crossattn + :param y: an [N] Tensor of labels, if class-conditional. + :return: an [N x C x ...] Tensor of outputs. + """ + assert (y is not None) == ( + self.num_classes is not None + ), "must specify y if and only if the model is class-conditional" + hs = [] + t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False) + emb = self.time_embed(t_emb) + + if self.num_classes is not None: + assert y.shape[0] == x.shape[0] + emb = emb + self.label_emb(y) + + h = x.type(self.dtype) + for module in self.input_blocks: + h = module(h, emb, context) + hs.append(h) + h = self.middle_block(h, emb, context) + for module in self.output_blocks: + h = th.cat([h, hs.pop()], dim=1) + h = module(h, emb, context) + h = h.type(x.dtype) + if self.predict_codebook_ids: + return self.id_predictor(h) + else: + return self.out(h) + + +class MultiViewUNetModel(nn.Module): + """ + The full multi-view UNet model with attention, timestep embedding and camera embedding. + :param in_channels: channels in the input Tensor. + :param model_channels: base channel count for the model. + :param out_channels: channels in the output Tensor. + :param num_res_blocks: number of residual blocks per downsample. + :param attention_resolutions: a collection of downsample rates at which + attention will take place. May be a set, list, or tuple. + For example, if this contains 4, then at 4x downsampling, attention + will be used. + :param dropout: the dropout probability. + :param channel_mult: channel multiplier for each level of the UNet. + :param conv_resample: if True, use learned convolutions for upsampling and + downsampling. + :param dims: determines if the signal is 1D, 2D, or 3D. 
+ :param num_classes: if specified (as an int), then this model will be + class-conditional with `num_classes` classes. + :param use_checkpoint: use gradient checkpointing to reduce memory usage. + :param num_heads: the number of attention heads in each attention layer. + :param num_heads_channels: if specified, ignore num_heads and instead use + a fixed channel width per attention head. + :param num_heads_upsample: works with num_heads to set a different number + of heads for upsampling. Deprecated. + :param use_scale_shift_norm: use a FiLM-like conditioning mechanism. + :param resblock_updown: use residual blocks for up/downsampling. + :param use_new_attention_order: use a different attention pattern for potentially + increased efficiency. + :param camera_dim: dimensionality of camera input. + """ + + def __init__( + self, + image_size, + in_channels, + model_channels, + out_channels, + num_res_blocks, + attention_resolutions, + dropout=0, + channel_mult=(1, 2, 4, 8), + conv_resample=True, + dims=2, + num_classes=None, + use_checkpoint=False, + use_fp16=False, + use_bf16=False, + num_heads=-1, + num_head_channels=-1, + num_heads_upsample=-1, + use_scale_shift_norm=False, + resblock_updown=False, + use_new_attention_order=False, + use_spatial_transformer=False, # custom transformer support + transformer_depth=1, # custom transformer support + context_dim=None, # custom transformer support + n_embed=None, # custom support for prediction of discrete ids into codebook of first stage vq model + legacy=True, + disable_self_attentions=None, + num_attention_blocks=None, + disable_middle_self_attn=False, + use_linear_in_transformer=False, + adm_in_channels=None, + camera_dim=None, + ): + super().__init__() + if use_spatial_transformer: + assert context_dim is not None, 'Fool!! You forgot to include the dimension of your cross-attention conditioning...' + + if context_dim is not None: + assert use_spatial_transformer, 'Fool!! You forgot to use the spatial transformer for your cross-attention conditioning...' + from omegaconf.listconfig import ListConfig + if type(context_dim) == ListConfig: + context_dim = list(context_dim) + + if num_heads_upsample == -1: + num_heads_upsample = num_heads + + if num_heads == -1: + assert num_head_channels != -1, 'Either num_heads or num_head_channels has to be set' + + if num_head_channels == -1: + assert num_heads != -1, 'Either num_heads or num_head_channels has to be set' + + self.image_size = image_size + self.in_channels = in_channels + self.model_channels = model_channels + self.out_channels = out_channels + if isinstance(num_res_blocks, int): + self.num_res_blocks = len(channel_mult) * [num_res_blocks] + else: + if len(num_res_blocks) != len(channel_mult): + raise ValueError("provide num_res_blocks either as an int (globally constant) or " + "as a list/tuple (per-level) with the same length as channel_mult") + self.num_res_blocks = num_res_blocks + if disable_self_attentions is not None: + # should be a list of booleans, indicating whether to disable self-attention in TransformerBlocks or not + assert len(disable_self_attentions) == len(channel_mult) + if num_attention_blocks is not None: + assert len(num_attention_blocks) == len(self.num_res_blocks) + assert all(map(lambda i: self.num_res_blocks[i] >= num_attention_blocks[i], range(len(num_attention_blocks)))) + print(f"Constructor of UNetModel received num_attention_blocks={num_attention_blocks}. 
" + f"This option has LESS priority than attention_resolutions {attention_resolutions}, " + f"i.e., in cases where num_attention_blocks[i] > 0 but 2**i not in attention_resolutions, " + f"attention will still not be set.") + + self.attention_resolutions = attention_resolutions + self.dropout = dropout + self.channel_mult = channel_mult + self.conv_resample = conv_resample + self.num_classes = num_classes + self.use_checkpoint = use_checkpoint + self.dtype = th.float16 if use_fp16 else th.float32 + self.dtype = th.bfloat16 if use_bf16 else self.dtype + self.num_heads = num_heads + self.num_head_channels = num_head_channels + self.num_heads_upsample = num_heads_upsample + self.predict_codebook_ids = n_embed is not None + + time_embed_dim = model_channels * 4 + self.time_embed = nn.Sequential( + linear(model_channels, time_embed_dim), + nn.SiLU(), + linear(time_embed_dim, time_embed_dim), + ) + + if self.num_classes is not None: + if isinstance(self.num_classes, int): + self.label_emb = nn.Embedding(num_classes, time_embed_dim) + elif self.num_classes == "continuous": + print("setting up linear c_adm embedding layer") + self.label_emb = Linear(1, time_embed_dim) + elif self.num_classes == "sequential": + assert adm_in_channels is not None + self.label_emb = nn.Sequential( + nn.Sequential( + linear(adm_in_channels, time_embed_dim), + nn.SiLU(), + linear(time_embed_dim, time_embed_dim), + ) + ) + else: + raise ValueError() + + self.input_blocks = nn.ModuleList( + [ + TimestepEmbedSequential( + conv_nd(dims, in_channels, model_channels, 3, padding=1) + ) + ] + ) + self._feature_size = model_channels + input_block_chans = [model_channels] + ch = model_channels + ds = 1 + for level, mult in enumerate(channel_mult): + for nr in range(self.num_res_blocks[level]): + layers = [ + ResBlock( + ch, + time_embed_dim, + dropout, + out_channels=mult * model_channels, + dims=dims, + use_checkpoint=use_checkpoint, + use_scale_shift_norm=use_scale_shift_norm, + ) + ] + ch = mult * model_channels + if ds in attention_resolutions: + if num_head_channels == -1: + dim_head = ch // num_heads + else: + num_heads = ch // num_head_channels + dim_head = num_head_channels + if legacy: + #num_heads = 1 + dim_head = ch // num_heads if use_spatial_transformer else num_head_channels + if exists(disable_self_attentions): + disabled_sa = disable_self_attentions[level] + else: + disabled_sa = False + + if not exists(num_attention_blocks) or nr < num_attention_blocks[level]: + layers.append( + AttentionBlock( + ch, + use_checkpoint=use_checkpoint, + num_heads=num_heads, + num_head_channels=dim_head, + use_new_attention_order=use_new_attention_order, + ) if not use_spatial_transformer else SpatialTransformer3D( + ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim, + disable_self_attn=disabled_sa, use_linear=use_linear_in_transformer, + use_checkpoint=use_checkpoint + ) + ) + self.input_blocks.append(TimestepEmbedSequential(*layers)) + self._feature_size += ch + input_block_chans.append(ch) + if level != len(channel_mult) - 1: + out_ch = ch + self.input_blocks.append( + TimestepEmbedSequential( + ResBlock( + ch, + time_embed_dim, + dropout, + out_channels=out_ch, + dims=dims, + use_checkpoint=use_checkpoint, + use_scale_shift_norm=use_scale_shift_norm, + down=True, + ) + if resblock_updown + else Downsample( + ch, conv_resample, dims=dims, out_channels=out_ch + ) + ) + ) + ch = out_ch + input_block_chans.append(ch) + ds *= 2 + self._feature_size += ch + + if num_head_channels == -1: + dim_head = ch // 
num_heads + else: + num_heads = ch // num_head_channels + dim_head = num_head_channels + if legacy: + #num_heads = 1 + dim_head = ch // num_heads if use_spatial_transformer else num_head_channels + self.middle_block = TimestepEmbedSequential( + ResBlock( + ch, + time_embed_dim, + dropout, + dims=dims, + use_checkpoint=use_checkpoint, + use_scale_shift_norm=use_scale_shift_norm, + ), + AttentionBlock( + ch, + use_checkpoint=use_checkpoint, + num_heads=num_heads, + num_head_channels=dim_head, + use_new_attention_order=use_new_attention_order, + ) if not use_spatial_transformer else SpatialTransformer3D( # always uses a self-attn + ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim, + disable_self_attn=disable_middle_self_attn, use_linear=use_linear_in_transformer, + use_checkpoint=use_checkpoint + ), + ResBlock( + ch, + time_embed_dim, + dropout, + dims=dims, + use_checkpoint=use_checkpoint, + use_scale_shift_norm=use_scale_shift_norm, + ), + ) + self._feature_size += ch + + self.output_blocks = nn.ModuleList([]) + for level, mult in list(enumerate(channel_mult))[::-1]: + for i in range(self.num_res_blocks[level] + 1): + ich = input_block_chans.pop() + layers = [ + ResBlock( + ch + ich, + time_embed_dim, + dropout, + out_channels=model_channels * mult, + dims=dims, + use_checkpoint=use_checkpoint, + use_scale_shift_norm=use_scale_shift_norm, + ) + ] + ch = model_channels * mult + if ds in attention_resolutions: + if num_head_channels == -1: + dim_head = ch // num_heads + else: + num_heads = ch // num_head_channels + dim_head = num_head_channels + if legacy: + #num_heads = 1 + dim_head = ch // num_heads if use_spatial_transformer else num_head_channels + if exists(disable_self_attentions): + disabled_sa = disable_self_attentions[level] + else: + disabled_sa = False + + if not exists(num_attention_blocks) or i < num_attention_blocks[level]: + layers.append( + AttentionBlock( + ch, + use_checkpoint=use_checkpoint, + num_heads=num_heads_upsample, + num_head_channels=dim_head, + use_new_attention_order=use_new_attention_order, + ) if not use_spatial_transformer else SpatialTransformer3D( + ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim, + disable_self_attn=disabled_sa, use_linear=use_linear_in_transformer, + use_checkpoint=use_checkpoint + ) + ) + if level and i == self.num_res_blocks[level]: + out_ch = ch + layers.append( + ResBlock( + ch, + time_embed_dim, + dropout, + out_channels=out_ch, + dims=dims, + use_checkpoint=use_checkpoint, + use_scale_shift_norm=use_scale_shift_norm, + up=True, + ) + if resblock_updown + else Upsample(ch, conv_resample, dims=dims, out_channels=out_ch) + ) + ds //= 2 + self.output_blocks.append(TimestepEmbedSequential(*layers)) + self._feature_size += ch + + self.out = nn.Sequential( + normalization(ch), + nn.SiLU(), + zero_module(conv_nd(dims, model_channels, out_channels, 3, padding=1)), + ) + if self.predict_codebook_ids: + self.id_predictor = nn.Sequential( + normalization(ch), + conv_nd(dims, model_channels, n_embed, 1), + #nn.LogSoftmax(dim=1) # change to cross_entropy and produce non-normalized logits + ) + + # self.cross_frames_attention_sizes = [32, 16, 8, 4] + + def convert_to_fp16(self): + """ + Convert the torso of the model to float16. + """ + self.input_blocks.apply(convert_module_to_f16) + self.middle_block.apply(convert_module_to_f16) + self.output_blocks.apply(convert_module_to_f16) + + def convert_to_fp32(self): + """ + Convert the torso of the model to float32. 
+ """ + self.input_blocks.apply(convert_module_to_f32) + self.middle_block.apply(convert_module_to_f32) + self.output_blocks.apply(convert_module_to_f32) + + def forward(self, x, timesteps=None, context=None, y=None, num_frames=1, **kwargs): + """ + Apply the model to an input batch. + :param x: an [(N x F) x C x ...] Tensor of inputs. F is the number of frames (views). + :param timesteps: a 1-D batch of timesteps. + :param context: conditioning plugged in via crossattn + :param y: an [N] Tensor of labels, if class-conditional. + :param num_frames: a integer indicating number of frames for tensor reshaping. + :return: an [(N x F) x C x ...] Tensor of outputs. F is the number of frames (views). + """ + assert x.shape[0] % num_frames == 0, "[UNet] input batch size must be dividable by num_frames!" + assert (y is not None) == ( + self.num_classes is not None + ), "must specify y if and only if the model is class-conditional" + hs = [] + t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False).to(self.dtype) + emb = self.time_embed(t_emb) + + if self.num_classes is not None: + assert y.shape[0] == x.shape[0] + emb = emb + self.label_emb(y.to(self.dtype)) + + h = x.to(self.dtype) + context = context.to(self.dtype) + for module in self.input_blocks: + h = module(h, emb, context, num_frames=num_frames) + hs.append(h) + h = self.middle_block(h, emb, context, num_frames=num_frames) + for module in self.output_blocks: + h = th.cat([h, hs.pop()], dim=1) + h = module(h, emb, context, num_frames=num_frames) + h = h + return self.out(h) \ No newline at end of file diff --git a/source/vae_hacked.py b/source/vae_hacked.py new file mode 100644 index 0000000000000000000000000000000000000000..934a59be14d12812407366dbf07f6292abfbcf88 --- /dev/null +++ b/source/vae_hacked.py @@ -0,0 +1,1054 @@ +# pytorch_diffusion + derived encoder decoder +import math +import torch +import torch.nn as nn +import torch.nn.functional as F +import numpy as np +from einops import rearrange +from typing import Optional, Any + +from .unet_hacked import MemoryEfficientCrossAttention + +try: + import xformers + import xformers.ops + XFORMERS_IS_AVAILBLE = True +except: + XFORMERS_IS_AVAILBLE = False + print("No module 'xformers'. Proceeding without it.") + +class Linear(nn.Linear): + def forward(self, input): + return F.linear(input, self.weight.to(input.dtype), self.bias.to(input.dtype)) + +def get_timestep_embedding(timesteps, embedding_dim): + """ + This matches the implementation in Denoising Diffusion Probabilistic Models: + From Fairseq. + Build sinusoidal embeddings. + This matches the implementation in tensor2tensor, but differs slightly + from the description in Section 3.5 of "Attention Is All You Need". 
+ """ + assert len(timesteps.shape) == 1 + + half_dim = embedding_dim // 2 + emb = math.log(10000) / (half_dim - 1) + emb = torch.exp(torch.arange(half_dim, dtype=torch.float32) * -emb) + emb = emb.to(device=timesteps.device) + emb = timesteps.float()[:, None] * emb[None, :] + emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1) + if embedding_dim % 2 == 1: # zero pad + emb = torch.nn.functional.pad(emb, (0,1,0,0)) + return emb + + +def nonlinearity(x): + # swish + return x*torch.sigmoid(x) + + +def Normalize(in_channels, num_groups=32): + return torch.nn.GroupNorm(num_groups=num_groups, num_channels=in_channels, eps=1e-6, affine=True) + + +class Upsample(nn.Module): + def __init__(self, in_channels, with_conv): + super().__init__() + self.with_conv = with_conv + if self.with_conv: + self.conv = torch.nn.Conv2d(in_channels, + in_channels, + kernel_size=3, + stride=1, + padding=1) + + def forward(self, x): + x = torch.nn.functional.interpolate(x, scale_factor=2.0, mode="nearest") + if self.with_conv: + x = self.conv(x) + return x + + +class Downsample(nn.Module): + def __init__(self, in_channels, with_conv): + super().__init__() + self.with_conv = with_conv + if self.with_conv: + # no asymmetric padding in torch conv, must do it ourselves + self.conv = torch.nn.Conv2d(in_channels, + in_channels, + kernel_size=3, + stride=2, + padding=0) + + def forward(self, x): + if self.with_conv: + pad = (0,1,0,1) + x = torch.nn.functional.pad(x, pad, mode="constant", value=0) + x = self.conv(x) + else: + x = torch.nn.functional.avg_pool2d(x, kernel_size=2, stride=2) + return x + + +class ResnetBlock(nn.Module): + def __init__(self, *, in_channels, out_channels=None, conv_shortcut=False, + dropout, temb_channels=512,padding_mode="zeros"): + super().__init__() + self.in_channels = in_channels + out_channels = in_channels if out_channels is None else out_channels + self.out_channels = out_channels + self.use_conv_shortcut = conv_shortcut + + self.norm1 = Normalize(in_channels) + self.conv1 = torch.nn.Conv2d(in_channels, + out_channels, + kernel_size=3, + stride=1, + padding=1, + padding_mode=padding_mode) + if temb_channels > 0: + self.temb_proj = torch.nn.Linear(temb_channels, + out_channels) + self.norm2 = Normalize(out_channels) + self.dropout = torch.nn.Dropout(dropout) + self.conv2 = torch.nn.Conv2d(out_channels, + out_channels, + kernel_size=3, + stride=1, + padding=1) + if self.in_channels != self.out_channels: + if self.use_conv_shortcut: + self.conv_shortcut = torch.nn.Conv2d(in_channels, + out_channels, + kernel_size=3, + stride=1, + padding=1) + else: + self.nin_shortcut = torch.nn.Conv2d(in_channels, + out_channels, + kernel_size=1, + stride=1, + padding=0) + + def forward(self, x, temb): + h = x + h = self.norm1(h) + h = nonlinearity(h) + h = self.conv1(h) + + if temb is not None: + h = h + self.temb_proj(nonlinearity(temb))[:,:,None,None] + + h = self.norm2(h) + h = nonlinearity(h) + h = self.dropout(h) + h = self.conv2(h) + + if self.in_channels != self.out_channels: + if self.use_conv_shortcut: + x = self.conv_shortcut(x) + else: + x = self.nin_shortcut(x) + + return x+h + + +class AttnBlock(nn.Module): + def __init__(self, in_channels): + super().__init__() + self.in_channels = in_channels + + self.norm = Normalize(in_channels) + self.q = torch.nn.Conv2d(in_channels, + in_channels, + kernel_size=1, + stride=1, + padding=0) + self.k = torch.nn.Conv2d(in_channels, + in_channels, + kernel_size=1, + stride=1, + padding=0) + self.v = torch.nn.Conv2d(in_channels, + in_channels, + 
kernel_size=1, + stride=1, + padding=0) + self.proj_out = torch.nn.Conv2d(in_channels, + in_channels, + kernel_size=1, + stride=1, + padding=0) + + def forward(self, x): + h_ = x + h_ = self.norm(h_) + q = self.q(h_) + k = self.k(h_) + v = self.v(h_) + + # compute attention + b,c,h,w = q.shape + q = q.reshape(b,c,h*w) + q = q.permute(0,2,1) # b,hw,c + k = k.reshape(b,c,h*w) # b,c,hw + w_ = torch.bmm(q,k) # b,hw,hw w[b,i,j]=sum_c q[b,i,c]k[b,c,j] + w_ = w_ * (int(c)**(-0.5)) + w_ = torch.nn.functional.softmax(w_, dim=2) + + # attend to values + v = v.reshape(b,c,h*w) + w_ = w_.permute(0,2,1) # b,hw,hw (first hw of k, second of q) + h_ = torch.bmm(v,w_) # b, c,hw (hw of q) h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j] + h_ = h_.reshape(b,c,h,w) + + h_ = self.proj_out(h_) + + return x+h_ + +class MemoryEfficientAttnBlock(nn.Module): + """ + Uses xformers efficient implementation, + see https://github.com/MatthieuTPHR/diffusers/blob/d80b531ff8060ec1ea982b65a1b8df70f73aa67c/src/diffusers/models/attention.py#L223 + Note: this is a single-head self-attention operation + """ + # + def __init__(self, in_channels): + super().__init__() + self.in_channels = in_channels + + self.norm = Normalize(in_channels) + self.q = torch.nn.Conv2d(in_channels, + in_channels, + kernel_size=1, + stride=1, + padding=0) + self.k = torch.nn.Conv2d(in_channels, + in_channels, + kernel_size=1, + stride=1, + padding=0) + self.v = torch.nn.Conv2d(in_channels, + in_channels, + kernel_size=1, + stride=1, + padding=0) + self.proj_out = torch.nn.Conv2d(in_channels, + in_channels, + kernel_size=1, + stride=1, + padding=0) + self.attention_op: Optional[Any] = None + + def forward(self, x): + h_ = x + h_ = self.norm(h_) + q = self.q(h_) + k = self.k(h_) + v = self.v(h_) + + # compute attention + B, C, H, W = q.shape + q, k, v = map(lambda x: rearrange(x, 'b c h w -> b (h w) c'), (q, k, v)) + + q, k, v = map( + lambda t: t.unsqueeze(3) + .reshape(B, t.shape[1], 1, C) + .permute(0, 2, 1, 3) + .reshape(B * 1, t.shape[1], C) + .contiguous(), + (q, k, v), + ) + out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=None, op=self.attention_op) + + out = ( + out.unsqueeze(0) + .reshape(B, 1, out.shape[1], C) + .permute(0, 2, 1, 3) + .reshape(B, out.shape[1], C) + ) + out = rearrange(out, 'b (h w) c -> b c h w', b=B, h=H, w=W, c=C) + out = self.proj_out(out) + return x+out + + +class MemoryEfficientCrossAttentionWrapper(MemoryEfficientCrossAttention): + def forward(self, x, context=None, mask=None): + b, c, h, w = x.shape + x = rearrange(x, 'b c h w -> b (h w) c') + if context is not None: + context = rearrange(context, 'b c h w -> b (h w) c') + out = super().forward(x, context=context, mask=mask) + out = rearrange(out, 'b (h w) c -> b c h w', h=h, w=w, c=c) + # x = rearrange(x, 'b (h w) c -> b c h w', h=h, w=w, c=c) + return out + + +def make_attn(in_channels, attn_type="vanilla", attn_kwargs=None): + assert attn_type in ["vanilla", "vanilla-xformers", "memory-efficient-cross-attn", "linear", "none"], f'attn_type {attn_type} unknown' + if XFORMERS_IS_AVAILBLE and attn_type == "vanilla": + attn_type = "vanilla-xformers" + print(f"making attention of type '{attn_type}' with {in_channels} in_channels") + if attn_type == "vanilla": + assert attn_kwargs is None + return AttnBlock(in_channels) + elif attn_type == "vanilla-xformers": + print(f"building MemoryEfficientAttnBlock with {in_channels} in_channels...") + return MemoryEfficientAttnBlock(in_channels) + elif attn_type == "memory-efficient-cross-attn": + attn_kwargs["query_dim"] = 
in_channels + return MemoryEfficientCrossAttentionWrapper(**attn_kwargs) + elif attn_type == "none": + return nn.Identity(in_channels) + else: + raise NotImplementedError() + + +class Model(nn.Module): + def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks, + attn_resolutions, dropout=0.0, resamp_with_conv=True, in_channels, + resolution, use_timestep=True, use_linear_attn=False, attn_type="vanilla"): + super().__init__() + if use_linear_attn: attn_type = "linear" + self.ch = ch + self.temb_ch = self.ch*4 + self.num_resolutions = len(ch_mult) + self.num_res_blocks = num_res_blocks + self.resolution = resolution + self.in_channels = in_channels + + self.use_timestep = use_timestep + if self.use_timestep: + # timestep embedding + self.temb = nn.Module() + self.temb.dense = nn.ModuleList([ + torch.nn.Linear(self.ch, + self.temb_ch), + torch.nn.Linear(self.temb_ch, + self.temb_ch), + ]) + + # downsampling + self.conv_in = torch.nn.Conv2d(in_channels, + self.ch, + kernel_size=3, + stride=1, + padding=1) + + curr_res = resolution + in_ch_mult = (1,)+tuple(ch_mult) + self.down = nn.ModuleList() + for i_level in range(self.num_resolutions): + block = nn.ModuleList() + attn = nn.ModuleList() + block_in = ch*in_ch_mult[i_level] + block_out = ch*ch_mult[i_level] + for i_block in range(self.num_res_blocks): + block.append(ResnetBlock(in_channels=block_in, + out_channels=block_out, + temb_channels=self.temb_ch, + dropout=dropout)) + block_in = block_out + if curr_res in attn_resolutions: + attn.append(make_attn(block_in, attn_type=attn_type)) + down = nn.Module() + down.block = block + down.attn = attn + if i_level != self.num_resolutions-1: + down.downsample = Downsample(block_in, resamp_with_conv) + curr_res = curr_res // 2 + self.down.append(down) + + # middle + self.mid = nn.Module() + self.mid.block_1 = ResnetBlock(in_channels=block_in, + out_channels=block_in, + temb_channels=self.temb_ch, + dropout=dropout) + self.mid.attn_1 = make_attn(block_in, attn_type=attn_type) + self.mid.block_2 = ResnetBlock(in_channels=block_in, + out_channels=block_in, + temb_channels=self.temb_ch, + dropout=dropout) + + # upsampling + self.up = nn.ModuleList() + for i_level in reversed(range(self.num_resolutions)): + block = nn.ModuleList() + attn = nn.ModuleList() + block_out = ch*ch_mult[i_level] + skip_in = ch*ch_mult[i_level] + for i_block in range(self.num_res_blocks+1): + if i_block == self.num_res_blocks: + skip_in = ch*in_ch_mult[i_level] + block.append(ResnetBlock(in_channels=block_in+skip_in, + out_channels=block_out, + temb_channels=self.temb_ch, + dropout=dropout)) + block_in = block_out + if curr_res in attn_resolutions: + attn.append(make_attn(block_in, attn_type=attn_type)) + up = nn.Module() + up.block = block + up.attn = attn + if i_level != 0: + up.upsample = Upsample(block_in, resamp_with_conv) + curr_res = curr_res * 2 + self.up.insert(0, up) # prepend to get consistent order + + # end + self.norm_out = Normalize(block_in) + self.conv_out = torch.nn.Conv2d(block_in, + out_ch, + kernel_size=3, + stride=1, + padding=1) + + def forward(self, x, t=None, context=None): + #assert x.shape[2] == x.shape[3] == self.resolution + if context is not None: + # assume aligned context, cat along channel axis + x = torch.cat((x, context), dim=1) + if self.use_timestep: + # timestep embedding + assert t is not None + temb = get_timestep_embedding(t, self.ch) + temb = self.temb.dense[0](temb) + temb = nonlinearity(temb) + temb = self.temb.dense[1](temb) + else: + temb = None + + # downsampling + hs = 
[self.conv_in(x)] + for i_level in range(self.num_resolutions): + for i_block in range(self.num_res_blocks): + h = self.down[i_level].block[i_block](hs[-1], temb) + if len(self.down[i_level].attn) > 0: + h = self.down[i_level].attn[i_block](h) + hs.append(h) + if i_level != self.num_resolutions-1: + hs.append(self.down[i_level].downsample(hs[-1])) + + # middle + h = hs[-1] + h = self.mid.block_1(h, temb) + h = self.mid.attn_1(h) + h = self.mid.block_2(h, temb) + + # upsampling + for i_level in reversed(range(self.num_resolutions)): + for i_block in range(self.num_res_blocks+1): + h = self.up[i_level].block[i_block]( + torch.cat([h, hs.pop()], dim=1), temb) + if len(self.up[i_level].attn) > 0: + h = self.up[i_level].attn[i_block](h) + if i_level != 0: + h = self.up[i_level].upsample(h) + + # end + h = self.norm_out(h) + h = nonlinearity(h) + h = self.conv_out(h) + return h + + def get_last_layer(self): + return self.conv_out.weight + + +class Encoder(nn.Module): + def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks, + attn_resolutions, dropout=0.0, resamp_with_conv=True, in_channels, + resolution, z_channels, double_z=True, use_linear_attn=False, attn_type="vanilla", + **ignore_kwargs): + super().__init__() + if use_linear_attn: attn_type = "linear" + self.ch = ch + self.temb_ch = 0 + self.num_resolutions = len(ch_mult) + self.num_res_blocks = num_res_blocks + self.resolution = resolution + self.in_channels = in_channels + + # downsampling + self.conv_in = torch.nn.Conv2d(in_channels, + self.ch, + kernel_size=3, + stride=1, + padding=1) + + curr_res = resolution + in_ch_mult = (1,)+tuple(ch_mult) + self.in_ch_mult = in_ch_mult + self.down = nn.ModuleList() + for i_level in range(self.num_resolutions): + block = nn.ModuleList() + attn = nn.ModuleList() + block_in = ch*in_ch_mult[i_level] + block_out = ch*ch_mult[i_level] + for i_block in range(self.num_res_blocks): + block.append(ResnetBlock(in_channels=block_in, + out_channels=block_out, + temb_channels=self.temb_ch, + dropout=dropout)) + block_in = block_out + if curr_res in attn_resolutions: + attn.append(make_attn(block_in, attn_type=attn_type)) + down = nn.Module() + down.block = block + down.attn = attn + if i_level != self.num_resolutions-1: + down.downsample = Downsample(block_in, resamp_with_conv) + curr_res = curr_res // 2 + self.down.append(down) + + # middle + self.mid = nn.Module() + self.mid.block_1 = ResnetBlock(in_channels=block_in, + out_channels=block_in, + temb_channels=self.temb_ch, + dropout=dropout) + self.mid.attn_1 = make_attn(block_in, attn_type=attn_type) + self.mid.block_2 = ResnetBlock(in_channels=block_in, + out_channels=block_in, + temb_channels=self.temb_ch, + dropout=dropout) + + # end + self.norm_out = Normalize(block_in) + self.conv_out = torch.nn.Conv2d(block_in, + 2*z_channels if double_z else z_channels, + kernel_size=3, + stride=1, + padding=1) + + def forward(self, x): + # timestep embedding + temb = None + + # downsampling + hs = [self.conv_in(x)] + for i_level in range(self.num_resolutions): + for i_block in range(self.num_res_blocks): + h = self.down[i_level].block[i_block](hs[-1], temb) + if len(self.down[i_level].attn) > 0: + h = self.down[i_level].attn[i_block](h) + hs.append(h) + if i_level != self.num_resolutions-1: + hs.append(self.down[i_level].downsample(hs[-1])) + + # middle + h = hs[-1] + h = self.mid.block_1(h, temb) + h = self.mid.attn_1(h) + h = self.mid.block_2(h, temb) + + # end + h = self.norm_out(h) + h = nonlinearity(h) + h = self.conv_out(h) + return h + + +class 
Decoder(nn.Module): + def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks, + attn_resolutions, dropout=0.0, resamp_with_conv=True, in_channels, + resolution, z_channels, give_pre_end=False, tanh_out=False, use_linear_attn=False, + attn_type="vanilla", with_attn=True,pano_pad=False,**ignorekwargs): + super().__init__() + # print with_attn + print(f"Decoder with attn: {with_attn}") + if use_linear_attn: attn_type = "linear" + self.ch = ch + self.temb_ch = 0 + self.num_resolutions = len(ch_mult) + self.num_res_blocks = num_res_blocks + self.resolution = resolution + self.in_channels = in_channels + self.give_pre_end = give_pre_end + self.tanh_out = tanh_out + + # compute in_ch_mult, block_in and curr_res at lowest res + # in_ch_mult = (1,)+tuple(ch_mult) + block_in = ch*ch_mult[self.num_resolutions-1] + curr_res = resolution // 2**(self.num_resolutions-1) + self.z_shape = (1,z_channels,curr_res,curr_res) + print("Working with z of shape {} = {} dimensions.".format( + self.z_shape, np.prod(self.z_shape))) + + # z to block_in + self.conv_in = torch.nn.Conv2d(z_channels, + block_in, + kernel_size=3, + stride=1, + padding=1, + padding_mode="zeros" if not pano_pad else "circular") + + # middle + self.mid = nn.Module() + self.mid.block_1 = ResnetBlock(in_channels=block_in, + out_channels=block_in, + temb_channels=self.temb_ch, + dropout=dropout) + if with_attn: + self.mid.attn_1 = make_attn(block_in, attn_type=attn_type) + self.mid.block_2 = ResnetBlock(in_channels=block_in, + out_channels=block_in, + temb_channels=self.temb_ch, + dropout=dropout) + + # upsampling + self.up = nn.ModuleList() + for i_level in reversed(range(self.num_resolutions)): + block = nn.ModuleList() + attn = nn.ModuleList() + block_out = ch*ch_mult[i_level] + for i_block in range(self.num_res_blocks+1): + block.append(ResnetBlock(in_channels=block_in, + out_channels=block_out, + temb_channels=self.temb_ch, + dropout=dropout, + padding_mode="zeros" if not pano_pad else "circular")) + block_in = block_out + if curr_res in attn_resolutions: + attn.append(make_attn(block_in, attn_type=attn_type)) + up = nn.Module() + up.block = block + up.attn = attn + if i_level != 0: + up.upsample = Upsample(block_in, resamp_with_conv) + curr_res = curr_res * 2 + self.up.insert(0, up) # prepend to get consistent order + + # end + self.norm_out = Normalize(block_in) + self.conv_out = torch.nn.Conv2d(block_in, + out_ch, + kernel_size=3, + stride=1, + padding=1, + padding_mode="zeros" if not pano_pad else "circular") + + def forward(self, z): + #assert z.shape[1:] == self.z_shape[1:] + self.last_z_shape = z.shape + + # timestep embedding + temb = None + + # z to block_in + h = self.conv_in(z) + # middle + h = self.mid.block_1(h, temb) + if hasattr(self.mid, "attn_1"): + h = self.mid.attn_1(h) + h = self.mid.block_2(h, temb) + + # upsampling + for i_level in reversed(range(self.num_resolutions)): + for i_block in range(self.num_res_blocks+1): + h = self.up[i_level].block[i_block](h, temb) + if len(self.up[i_level].attn) > 0: + h = self.up[i_level].attn[i_block](h) + # print(h.std()) + if i_level != 0: + h = self.up[i_level].upsample(h) + # print(h.std()) + + # end + if self.give_pre_end: + return h + + + h = self.norm_out(h) + h = nonlinearity(h) + h = self.conv_out(h) + if self.tanh_out: + h = torch.tanh(h) + return h + +""" +for i_level in reversed(range(self.num_resolutions)): + for i_block in range(self.num_res_blocks+1): + print(i_level, i_block, h.std()) + h = self.up[i_level].block[i_block](h, temb) + if h.isnan().any(): + 
print(i_level) + print(i_block) + break + if len(self.up[i_level].attn) > 0: + h = self.up[i_level].attn[i_block](h) + if h.isnan().any(): + print(i_level) + print(i_block) + break + if i_level != 0: + h = self.up[i_level].upsample(h) + if h.isnan().any(): + print(i_level) + print(i_block) + break + +""" + +class SimpleDecoder(nn.Module): + def __init__(self, in_channels, out_channels, *args, **kwargs): + super().__init__() + self.model = nn.ModuleList([nn.Conv2d(in_channels, in_channels, 1), + ResnetBlock(in_channels=in_channels, + out_channels=2 * in_channels, + temb_channels=0, dropout=0.0), + ResnetBlock(in_channels=2 * in_channels, + out_channels=4 * in_channels, + temb_channels=0, dropout=0.0), + ResnetBlock(in_channels=4 * in_channels, + out_channels=2 * in_channels, + temb_channels=0, dropout=0.0), + nn.Conv2d(2*in_channels, in_channels, 1), + Upsample(in_channels, with_conv=True)]) + # end + self.norm_out = Normalize(in_channels) + self.conv_out = torch.nn.Conv2d(in_channels, + out_channels, + kernel_size=3, + stride=1, + padding=1) + + def forward(self, x): + for i, layer in enumerate(self.model): + if i in [1,2,3]: + x = layer(x, None) + else: + x = layer(x) + + h = self.norm_out(x) + h = nonlinearity(h) + x = self.conv_out(h) + return x + + +class UpsampleDecoder(nn.Module): + def __init__(self, in_channels, out_channels, ch, num_res_blocks, resolution, + ch_mult=(2,2), dropout=0.0): + super().__init__() + # upsampling + self.temb_ch = 0 + self.num_resolutions = len(ch_mult) + self.num_res_blocks = num_res_blocks + block_in = in_channels + curr_res = resolution // 2 ** (self.num_resolutions - 1) + self.res_blocks = nn.ModuleList() + self.upsample_blocks = nn.ModuleList() + for i_level in range(self.num_resolutions): + res_block = [] + block_out = ch * ch_mult[i_level] + for i_block in range(self.num_res_blocks + 1): + res_block.append(ResnetBlock(in_channels=block_in, + out_channels=block_out, + temb_channels=self.temb_ch, + dropout=dropout)) + block_in = block_out + self.res_blocks.append(nn.ModuleList(res_block)) + if i_level != self.num_resolutions - 1: + self.upsample_blocks.append(Upsample(block_in, True)) + curr_res = curr_res * 2 + + # end + self.norm_out = Normalize(block_in) + self.conv_out = torch.nn.Conv2d(block_in, + out_channels, + kernel_size=3, + stride=1, + padding=1) + + def forward(self, x): + # upsampling + h = x + for k, i_level in enumerate(range(self.num_resolutions)): + for i_block in range(self.num_res_blocks + 1): + h = self.res_blocks[i_level][i_block](h, None) + if i_level != self.num_resolutions - 1: + h = self.upsample_blocks[k](h) + h = self.norm_out(h) + h = nonlinearity(h) + h = self.conv_out(h) + return h + + +class LatentRescaler(nn.Module): + def __init__(self, factor, in_channels, mid_channels, out_channels, depth=2): + super().__init__() + # residual block, interpolate, residual block + self.factor = factor + self.conv_in = nn.Conv2d(in_channels, + mid_channels, + kernel_size=3, + stride=1, + padding=1) + self.res_block1 = nn.ModuleList([ResnetBlock(in_channels=mid_channels, + out_channels=mid_channels, + temb_channels=0, + dropout=0.0) for _ in range(depth)]) + self.attn = AttnBlock(mid_channels) + self.res_block2 = nn.ModuleList([ResnetBlock(in_channels=mid_channels, + out_channels=mid_channels, + temb_channels=0, + dropout=0.0) for _ in range(depth)]) + + self.conv_out = nn.Conv2d(mid_channels, + out_channels, + kernel_size=1, + ) + + def forward(self, x): + x = self.conv_in(x) + for block in self.res_block1: + x = block(x, None) + x = 
torch.nn.functional.interpolate(x, size=(int(round(x.shape[2]*self.factor)), int(round(x.shape[3]*self.factor)))) + x = self.attn(x) + for block in self.res_block2: + x = block(x, None) + x = self.conv_out(x) + return x + + +
class MergedRescaleEncoder(nn.Module): + def __init__(self, in_channels, ch, resolution, out_ch, num_res_blocks, + attn_resolutions, dropout=0.0, resamp_with_conv=True, + ch_mult=(1,2,4,8), rescale_factor=1.0, rescale_module_depth=1): + super().__init__() + intermediate_chn = ch * ch_mult[-1] + self.encoder = Encoder(in_channels=in_channels, num_res_blocks=num_res_blocks, ch=ch, ch_mult=ch_mult, + z_channels=intermediate_chn, double_z=False, resolution=resolution, + attn_resolutions=attn_resolutions, dropout=dropout, resamp_with_conv=resamp_with_conv, + out_ch=None) + self.rescaler = LatentRescaler(factor=rescale_factor, in_channels=intermediate_chn, + mid_channels=intermediate_chn, out_channels=out_ch, depth=rescale_module_depth) + + def forward(self, x): + x = self.encoder(x) + x = self.rescaler(x) + return x + + +
class MergedRescaleDecoder(nn.Module): + def __init__(self, z_channels, out_ch, resolution, num_res_blocks, attn_resolutions, ch, ch_mult=(1,2,4,8), + dropout=0.0, resamp_with_conv=True, rescale_factor=1.0, rescale_module_depth=1): + super().__init__() + tmp_chn = z_channels*ch_mult[-1] + self.decoder = Decoder(out_ch=out_ch, z_channels=tmp_chn, attn_resolutions=attn_resolutions, dropout=dropout, + resamp_with_conv=resamp_with_conv, in_channels=None, num_res_blocks=num_res_blocks, + ch_mult=ch_mult, resolution=resolution, ch=ch) + self.rescaler = LatentRescaler(factor=rescale_factor, in_channels=z_channels, mid_channels=tmp_chn, + out_channels=tmp_chn, depth=rescale_module_depth) + + def forward(self, x): + x = self.rescaler(x) + x = self.decoder(x) + return x + + +
class Upsampler(nn.Module): + def __init__(self, in_size, out_size, in_channels, out_channels, ch_mult=2): + super().__init__() + assert out_size >= in_size + num_blocks = int(np.log2(out_size//in_size))+1 + factor_up = 1.+ (out_size % in_size) + print(f"Building {self.__class__.__name__} with in_size: {in_size} --> out_size {out_size} and factor {factor_up}") + self.rescaler = LatentRescaler(factor=factor_up, in_channels=in_channels, mid_channels=2*in_channels, + out_channels=in_channels) + self.decoder = Decoder(out_ch=out_channels, resolution=out_size, z_channels=in_channels, num_res_blocks=2, + attn_resolutions=[], in_channels=None, ch=in_channels, + ch_mult=[ch_mult for _ in range(num_blocks)]) + + def forward(self, x): + x = self.rescaler(x) + x = self.decoder(x) + return x + + +
class Resize(nn.Module): + def __init__(self, in_channels=None, learned=False, mode="bilinear"): + super().__init__() + self.with_conv = learned + self.mode = mode + if self.with_conv: + print(f"Note: {self.__class__.__name__} uses learned downsampling and will ignore the fixed {mode} mode") + raise NotImplementedError() + assert in_channels is not None + # no asymmetric padding in torch conv, must do it ourselves + self.conv = torch.nn.Conv2d(in_channels, + in_channels, + kernel_size=4, + stride=2, + padding=1) + + def forward(self, x, scale_factor=1.0): + if scale_factor==1.0: + return x + else: + x = torch.nn.functional.interpolate(x, mode=self.mode, align_corners=False, scale_factor=scale_factor) + return x + + import torch + import torch.nn.functional as F + import numpy as np + + + class AbstractDistribution: + def sample(self): + raise NotImplementedError() + + def mode(self): + raise
NotImplementedError() + + +class DiracDistribution(AbstractDistribution): + def __init__(self, value): + self.value = value + + def sample(self): + return self.value + + def mode(self): + return self.value + + +class DiagonalGaussianDistribution(object): + def __init__(self, parameters, deterministic=False): + self.parameters = parameters + self.mean, self.logvar = torch.chunk(parameters, 2, dim=1) + self.logvar = torch.clamp(self.logvar, -30.0, 20.0) + self.deterministic = deterministic + self.std = torch.exp(0.5 * self.logvar) + self.var = torch.exp(self.logvar) + if self.deterministic: + self.var = self.std = torch.zeros_like(self.mean).to(device=self.parameters.device) + + def sample(self,z=None): + if z is None: + z = torch.randn(self.mean.shape).to(device=self.parameters.device) + x = self.mean + self.std * z + return x, z + + def kl(self, other=None): + if self.deterministic: + return torch.Tensor([0.]) + else: + if other is None: + return 0.5 * torch.sum(torch.pow(self.mean, 2) + + self.var - 1.0 - self.logvar, + dim=[1, 2, 3]) + else: + return 0.5 * torch.sum( + torch.pow(self.mean - other.mean, 2) / other.var + + self.var / other.var - 1.0 - self.logvar + other.logvar, + dim=[1, 2, 3]) + + def nll(self, sample, dims=[1,2,3]): + if self.deterministic: + return torch.Tensor([0.]) + logtwopi = np.log(2.0 * np.pi) + return 0.5 * torch.sum( + logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var, + dim=dims) + + def mode(self): + return self.mean + + +def normal_kl(mean1, logvar1, mean2, logvar2): + """ + source: https://github.com/openai/guided-diffusion/blob/27c20a8fab9cb472df5d6bdd6c8d11c8f430b924/guided_diffusion/losses.py#L12 + Compute the KL divergence between two gaussians. + Shapes are automatically broadcasted, so batches can be compared to + scalars, among other use cases. + """ + tensor = None + for obj in (mean1, logvar1, mean2, logvar2): + if isinstance(obj, torch.Tensor): + tensor = obj + break + assert tensor is not None, "at least one argument must be a Tensor" + + # Force variances to be Tensors. Broadcasting helps convert scalars to + # Tensors, but it does not work for torch.exp(). 
+ logvar1, logvar2 = [ + x if isinstance(x, torch.Tensor) else torch.tensor(x).to(tensor) + for x in (logvar1, logvar2) + ] + + return 0.5 * ( + -1.0 + + logvar2 + - logvar1 + + torch.exp(logvar1 - logvar2) + + ((mean1 - mean2) ** 2) * torch.exp(-logvar2) + ) + +
class AutoencoderKL(torch.nn.Module): + def __init__(self, + ddconfig, + embed_dim + ): + super().__init__() + self.encoder = Encoder(**ddconfig) + self.decoder = Decoder(**ddconfig) + assert ddconfig["double_z"] + self.quant_conv = torch.nn.Conv2d(2*ddconfig["z_channels"], 2*embed_dim, 1) + self.post_quant_conv = torch.nn.Conv2d(embed_dim, ddconfig["z_channels"], 1) + self.embed_dim = embed_dim + + def encode(self, x, z=None): + h = self.encoder(x) + moments = self.quant_conv(h) + posterior = DiagonalGaussianDistribution(moments) + return posterior + + def decode(self, z, extra_z=None, path='new'): + post_quant_conv = self.post_quant_conv if path == 'new' else self.old_post_quant_conv + decoder = self.decoder if path == 'new' else self.old_decoder + z = post_quant_conv(z) + if extra_z is not None: + z = torch.cat([z, extra_z], dim=1) + dec = decoder(z) + return dec + + def forward(self, input, sample_posterior=True): + posterior = self.encode(input) + if sample_posterior: + # sample() returns (sample, noise); only the sample is decoded + z, _ = posterior.sample() + else: + z = posterior.mode() + dec = self.decode(z) + return dec, posterior + + def get_input(self, batch, k): + x = batch[k] + if len(x.shape) == 3: + x = x[..., None] + x = x.permute(0, 3, 1, 2).to(memory_format=torch.contiguous_format).float() + return x + +
class IdentityFirstStage(torch.nn.Module): + def __init__(self, *args, vq_interface=False, **kwargs): + super().__init__() + self.vq_interface = vq_interface + + def encode(self, x, *args, **kwargs): + return x + + def decode(self, x, *args, **kwargs): + return x + + def quantize(self, x, *args, **kwargs): + if self.vq_interface: + return x, None, [None, None, None] + return x + + def forward(self, x, *args, **kwargs): + return x \ No newline at end of file diff --git a/source/xyz2thetaphi.py b/source/xyz2thetaphi.py new file mode 100644 index 0000000000000000000000000000000000000000..f60cd8ccf73437e534b2185576f8e6b1d9934e2f --- /dev/null +++ b/source/xyz2thetaphi.py @@ -0,0 +1,24 @@ +import torch +import math + + +def xyz2thetaphi(xyz): + """ + xyz: (N, ..., 3) tensor of 3D directions. + Returns a (N, ..., 2) tensor of (theta, phi), each normalized to [-1, 1]. + """ + # Normalize the input tensor + xyz = xyz / torch.norm(xyz, dim=-1, keepdim=True) + + # Calculate theta and phi + theta = torch.acos(xyz[..., 2]) # data range [0, pi] + phi = torch.atan2(xyz[..., 1], xyz[..., 0]) # data range [-pi, pi] + + # to [-1, 1] + theta = (theta / math.pi) * 2 - 1 + phi = phi / math.pi + + # cat + thetaphi = torch.cat([theta.unsqueeze(-1), phi.unsqueeze(-1)], dim=-1) + + return thetaphi \ No newline at end of file
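
A note on the two attention paths in the model code above: `AttnBlock` materializes the full `(h*w) x (h*w)` attention matrix with `torch.bmm`, while `MemoryEfficientAttnBlock` computes the same single-head self-attention through `xformers.ops.memory_efficient_attention` without ever storing that matrix, which is why `make_attn` upgrades `"vanilla"` to `"vanilla-xformers"` whenever xformers is importable. A minimal sketch of the equivalence, using PyTorch's built-in `scaled_dot_product_attention` as a stand-in for the xformers kernel (an assumption for illustration only; the repo itself calls xformers):

```python
import torch
import torch.nn.functional as F

B, C, H, W = 1, 8, 16, 16
q, k, v = (torch.randn(B, C, H, W) for _ in range(3))

# Vanilla path, as in AttnBlock.forward: explicit (HW x HW) attention matrix.
q2 = q.reshape(B, C, H * W).permute(0, 2, 1)              # (B, HW, C)
k2 = k.reshape(B, C, H * W)                               # (B, C, HW)
w_ = torch.softmax(torch.bmm(q2, k2) * C ** -0.5, dim=2)  # (B, HW, HW)
out_vanilla = torch.bmm(v.reshape(B, C, H * W), w_.permute(0, 2, 1)).reshape(B, C, H, W)

# Fused path: same math, no explicit (HW x HW) matrix in memory.
q3, k3, v3 = (t.reshape(B, C, H * W).permute(0, 2, 1) for t in (q, k, v))
out_fused = F.scaled_dot_product_attention(q3, k3, v3)    # default scale = C**-0.5
out_fused = out_fused.permute(0, 2, 1).reshape(B, C, H, W)

assert torch.allclose(out_vanilla, out_fused, atol=1e-4)
```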
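`Model` is a DDPM-style UNet: when `use_timestep` is set, `forward` turns the integer step `t` into a `(B, ch)` sinusoidal embedding via `get_timestep_embedding` (defined earlier in the file, not shown here) and pushes it through the two `temb.dense` linear layers before handing it to every `ResnetBlock`. For reference, a self-contained sketch of the standard sinusoidal embedding this pattern assumes; exact frequency scaling varies slightly between codebases, so treat the repo's own `get_timestep_embedding` as authoritative:

```python
import math
import torch

def sinusoidal_timestep_embedding(timesteps: torch.Tensor, dim: int) -> torch.Tensor:
    """DDPM-style embedding: (B,) integer timesteps -> (B, dim) features (dim must be even)."""
    half = dim // 2
    freqs = torch.exp(-math.log(10000.0) * torch.arange(half, dtype=torch.float32) / half)
    args = timesteps.float()[:, None] * freqs[None, :]           # (B, half)
    return torch.cat([torch.sin(args), torch.cos(args)], dim=1)  # (B, dim)

temb = sinusoidal_timestep_embedding(torch.tensor([0, 10, 999]), dim=128)
print(temb.shape)  # torch.Size([3, 128])
```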
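`Encoder` and `Decoder` are symmetric: the encoder halves the spatial resolution `len(ch_mult) - 1` times and, with `double_z=True`, emits `2 * z_channels` channels (mean and log-variance stacked, consumed by `DiagonalGaussianDistribution`), while the decoder starts from `resolution // 2**(num_resolutions - 1)` and doubles back up. A quick arithmetic check, assuming a typical KL-f8-style config (the values below are illustrative, not read from the repo's actual config):

```python
# Illustrative config; the real values live in the repo's model config.
ddconfig = dict(ch=128, ch_mult=(1, 2, 4, 4), num_res_blocks=2,
                resolution=256, z_channels=4, double_z=True)

num_resolutions = len(ddconfig["ch_mult"])                       # 4 levels
latent_res = ddconfig["resolution"] // 2 ** (num_resolutions - 1)
enc_out_ch = (2 if ddconfig["double_z"] else 1) * ddconfig["z_channels"]

assert latent_res == 32   # 256 -> 128 -> 64 -> 32 (3 downsamples)
assert enc_out_ch == 8    # mean and logvar, 4 channels each
```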
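The `pano_pad` flag on `Decoder` swaps `padding_mode="zeros"` for `"circular"` in its convolutions, so the left and right edges of an equirectangular panorama are treated as neighbours and the decoded output has no visible seam at the 360° wrap. Circular padding makes the convolution equivariant to horizontal rolls, which a few lines can verify:

```python
import torch

torch.manual_seed(0)
conv = torch.nn.Conv2d(3, 3, kernel_size=3, padding=1, padding_mode="circular")
x = torch.randn(1, 3, 8, 16)

# Rolling the input around the horizontal (longitude) axis and convolving
# gives the same result as convolving first and rolling the output.
out_a = conv(torch.roll(x, shifts=5, dims=-1))
out_b = torch.roll(conv(x), shifts=5, dims=-1)
assert torch.allclose(out_a, out_b, atol=1e-6)
```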
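`DiagonalGaussianDistribution` implements the usual VAE reparameterization, with one local quirk worth flagging: `sample()` returns the pair `(x, z)`, the sample together with the noise that produced it (so the noise can be passed back in and replayed), which is why `AutoencoderKL.forward` unpacks `z, _ = posterior.sample()`. Its closed-form `kl()` against a standard normal matches `torch.distributions`, as this sketch checks:

```python
import torch
from torch.distributions import Normal, kl_divergence

mean = torch.randn(2, 4, 8, 8)
logvar = torch.randn(2, 4, 8, 8)
var, std = logvar.exp(), (0.5 * logvar).exp()

# Reparameterized sample, as in DiagonalGaussianDistribution.sample().
noise = torch.randn_like(mean)
x = mean + std * noise

# Closed form used by kl(): 0.5 * sum(mu^2 + var - 1 - logvar).
kl_closed = 0.5 * torch.sum(mean ** 2 + var - 1.0 - logvar, dim=[1, 2, 3])
kl_ref = kl_divergence(Normal(mean, std), Normal(0.0, 1.0)).sum(dim=[1, 2, 3])
assert torch.allclose(kl_closed, kl_ref, atol=1e-3)  # loose tolerance for float32 sums
```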
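Finally, `source/xyz2thetaphi.py` maps unit 3D directions to spherical coordinates scaled to `[-1, 1]`, a convenient parameterization for equirectangular panorama lookups. The mapping is invertible away from the poles; the `thetaphi2xyz` below is a hypothetical inverse that is not part of the repo, defined here only to round-trip the conversion as a sanity check:

```python
import math
import torch

def thetaphi2xyz(thetaphi: torch.Tensor) -> torch.Tensor:
    # Hypothetical inverse of xyz2thetaphi, for testing only.
    theta = (thetaphi[..., 0] + 1) / 2 * math.pi   # back to [0, pi]
    phi = thetaphi[..., 1] * math.pi               # back to [-pi, pi]
    sin_t = torch.sin(theta)
    return torch.stack([sin_t * torch.cos(phi),
                        sin_t * torch.sin(phi),
                        torch.cos(theta)], dim=-1)

xyz = torch.nn.functional.normalize(torch.randn(1000, 3, dtype=torch.float64), dim=-1)
theta = torch.acos(xyz[..., 2]) / math.pi * 2 - 1           # as in xyz2thetaphi
phi = torch.atan2(xyz[..., 1], xyz[..., 0]) / math.pi
thetaphi = torch.stack([theta, phi], dim=-1)
assert torch.allclose(thetaphi2xyz(thetaphi), xyz, atol=1e-8)
```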