Sat3DGen / source /xyz2thetaphi.py
qian43's picture
Upload 115 files
874cec4 verified
import torch
import numpy as np
def xyz2thetaphi(xyz: torch.Tensor) -> torch.Tensor:
    """Convert Cartesian direction vectors to normalised spherical angles.

    Each vector along the last axis is scaled to unit length and then
    expressed as a polar angle ``theta`` (measured from the +z axis) and an
    azimuth ``phi`` (measured in the x-y plane from the +x axis). Both
    angles are rescaled linearly to ``[-1, 1]``.

    Args:
        xyz: Floating-point tensor of shape ``(N, ..., 3)`` holding
            ``(x, y, z)`` components along the last axis. The vectors do
            not need to be unit length.

    Returns:
        Tensor of shape ``(N, ..., 2)`` whose last axis is ``(theta, phi)``:

        * ``theta``: ``acos(z)`` mapped from ``[0, pi]`` to ``[-1, 1]``
          (``-1`` is the +z direction, ``+1`` is the -z direction).
        * ``phi``: ``atan2(y, x)`` mapped from ``[-pi, pi]`` to ``[-1, 1]``.

        Zero-length input vectors have no defined direction; they are
        mapped to ``(0, 0)`` instead of producing NaN.
    """
    import math

    # Normalise to unit length. A zero norm would turn the whole row into
    # NaN (0 / 0), so an exactly-zero norm is replaced by 1, which leaves
    # the zero vector untouched. NaN norms are left alone so that invalid
    # input still propagates as NaN.
    norm = torch.norm(xyz, dim=-1, keepdim=True)
    safe_norm = torch.where(norm > 0, norm, torch.ones_like(norm))
    unit = xyz / safe_norm

    # Defensive clamp: acos returns NaN for arguments even marginally
    # outside [-1, 1], which floating-point rounding in the normalisation
    # could produce.
    z = unit[..., 2].clamp(-1.0, 1.0)
    theta = torch.acos(z)  # range [0, pi]
    phi = torch.atan2(unit[..., 1], unit[..., 0])  # range [-pi, pi]

    # Rescale both angles to [-1, 1].
    theta = (theta / math.pi) * 2 - 1
    phi = phi / math.pi

    return torch.stack([theta, phi], dim=-1)