import math

import numpy as np
import torch


def xyz2thetaphi(xyz):
    """Convert 3-D direction vectors to normalized spherical angles.

    Args:
        xyz: tensor whose last dimension has size 3, holding (x, y, z)
            components, e.g. shape (N, ..., 3). Vectors need not be unit
            length; they are normalized here.

    Returns:
        Tensor of shape ``xyz.shape[:-1] + (2,)`` whose last dimension is
        ``(theta, phi)``, both rescaled to [-1, 1]:

        * theta is the polar angle from the +z axis, originally in [0, pi].
          It is mapped linearly so that +z -> -1 and -z -> +1.
        * phi is the azimuth ``atan2(y, x)``, originally in [-pi, pi].
          It is divided by pi.

        A zero-length input vector yields finite values (theta = 0, phi = 0)
        rather than NaN.
    """
    # Epsilon-clamped normalization: avoids 0/0 -> NaN for zero vectors.
    unit = torch.nn.functional.normalize(xyz, dim=-1)

    # Rounding after normalization can leave |z| marginally above 1, which
    # would make acos return NaN, so clamp to acos's domain first.
    cos_theta = unit[..., 2].clamp(-1.0, 1.0)
    theta = torch.acos(cos_theta)  # [0, pi]
    phi = torch.atan2(unit[..., 1], unit[..., 0])  # [-pi, pi]

    # Rescale both angles to [-1, 1].
    theta = (theta / math.pi) * 2 - 1
    phi = phi / math.pi

    return torch.stack([theta, phi], dim=-1)