# Modified from https://antimatter15.com/splat from plyfile import PlyData import numpy as np import argparse from io import BytesIO def splat_to_numpy(file_path): with open(file_path, 'rb') as f: splat_data = f.read() splat_dtype = np.dtype([ ('position', np.float32, 3), ('scale', np.float32, 3), ('color', np.uint8, 4), ('rotation', np.uint8, 4) ]) splat_array = np.frombuffer(splat_data, dtype=splat_dtype) points = splat_array["position"] scales = splat_array["scale"] rots = (splat_array["rotation"]/255)*2 - 1 color = splat_array["color"]/255 return points, scales, rots.astype(np.float32), color.astype(np.float32) def ply_to_numpy(ply_file_path): plydata = PlyData.read(ply_file_path) vert = plydata["vertex"] sorted_indices = np.argsort( -np.exp(vert["scale_0"] + vert["scale_1"] + vert["scale_2"]) / (1 + np.exp(-vert["opacity"])) ) buffer = BytesIO() positions = np.zeros((len(sorted_indices), 3), dtype=np.float32) scales = np.zeros((len(sorted_indices), 3), dtype=np.float32) rots = np.zeros((len(sorted_indices), 4), dtype=np.float32) colors = np.zeros((len(sorted_indices), 4), dtype=np.float32) for idx in sorted_indices: v = plydata["vertex"][idx] position = np.array([v["x"], v["y"], v["z"]], dtype=np.float32) scale = np.exp( np.array( [v["scale_0"], v["scale_1"], v["scale_2"]], dtype=np.float32, ) ) rot = np.array( [v["rot_0"], v["rot_1"], v["rot_2"], v["rot_3"]], dtype=np.float32, ) SH_C0 = 0.28209479177387814 color = np.array( [ 0.5 + SH_C0 * v["f_dc_0"], 0.5 + SH_C0 * v["f_dc_1"], 0.5 + SH_C0 * v["f_dc_2"], 1 / (1 + np.exp(-v["opacity"])), ] ) positions[idx] = position scales[idx] = scale rots[idx] = rot colors[idx] = color return positions, scales, rots, colors def numpy_to_splat(positions, scales, rots, colors, output_path, file_type): buffer = BytesIO() if file_type == 'ply': for idx in range(len(positions)): position = positions[idx] scale = scales[idx] rot = rots[idx] color = colors[idx] buffer.write(position.tobytes()) buffer.write(scale.tobytes()) buffer.write((color * 255).clip(0, 255).astype(np.uint8).tobytes()) buffer.write(((rot / np.linalg.norm(rot)) * 128 + 128).clip(0, 255).astype(np.uint8).tobytes() ) splat_data = buffer.getvalue() with open(output_path, "wb") as f: f.write(splat_data) else: for idx in range(len(positions)): position = positions[idx] scale = scales[idx] rot = rots[idx] color = colors[idx] buffer.write(position.tobytes()) buffer.write(scale.tobytes()) buffer.write((color * 255).clip(0, 255).astype(np.uint8).tobytes()) buffer.write(((rot / np.linalg.norm(rot)) * 128 + 128).clip(0, 255).astype(np.uint8).tobytes() ) splat_data = buffer.getvalue() with open(output_path, "wb") as f: f.write(splat_data) return splat_data def main(): parser = argparse.ArgumentParser(description="Convert PLY files to SPLAT format.") parser.add_argument( "input_files", nargs="+", help="The input PLY files to process." ) parser.add_argument( "--output", "-o", default="output.splat", help="The output SPLAT file." ) args = parser.parse_args() for input_file in args.input_files: print(f"Processing {input_file}...") positions, scales, rotations, colors = ply_to_numpy(input_file) numpy_to_splat(positions, scales, rotations, colors, args.output) if __name__ == "__main__": main()