import torch
import numpy as np
import healpy as hp
from sklearn.neighbors import BallTree


def rad_to_xyz(lonlat: torch.Tensor) -> torch.Tensor:
    """Convert longitude/latitude (in radians) to points on the unit sphere.

    Args:
        lonlat: Tensor whose last dimension holds ``[lon, lat]`` in radians
            (index 0 is longitude, index 1 is latitude). Leading dimensions
            are preserved.

    Returns:
        Tensor of shape ``(..., 3)`` holding ``[x, y, z]`` along the last
        dimension.
    """
    lon = lonlat[..., 0]
    lat = lonlat[..., 1]
    # cos(lat) is the radius of the circle of constant latitude; lon then
    # rotates the point around the z axis within that circle.
    cos_lat = torch.cos(lat)
    x = cos_lat * torch.cos(lon)
    y = cos_lat * torch.sin(lon)
    z = torch.sin(lat)
    return torch.stack([x, y, z], dim=-1)


def get_healpix_grid(nside: int) -> torch.Tensor:
    """Return the centre coordinates of every HEALPix pixel for ``nside``.

    Pixels are enumerated in NESTED ordering (``nest=True``).

    Args:
        nside: HEALPix resolution parameter.

    Returns:
        float32 tensor of shape ``(npix, 2)`` whose columns are
        ``[longitude, latitude]`` in **degrees**, where latitude is
        ``90 - colatitude``.
    """
    indices = np.arange(hp.nside2npix(nside))
    # healpy returns colatitude (theta) and longitude (phi), both in radians.
    theta, phi = hp.pix2ang(nside, indices, nest=True)
    lon_deg = np.rad2deg(phi)
    lat_deg = 90.0 - np.rad2deg(theta)
    return torch.stack(
        (torch.from_numpy(lon_deg), torch.from_numpy(lat_deg)), dim=-1
    ).float()


def get_neighbors(pos_from: np.ndarray, pos_to: np.ndarray, k: int = 8) -> np.ndarray:
    """Find the ``k`` nearest ``pos_from`` points for every ``pos_to`` point.

    A BallTree is built over ``pos_from`` using scikit-learn's haversine
    (great-circle) metric and queried with ``pos_to``.

    Args:
        pos_from: Array of shape ``(n_from, 2)`` with ``[lon, lat]`` columns;
            the returned indices point into this array. Anything
            ``np.asarray`` accepts (e.g. a CPU torch tensor) also works.
        pos_to: Array of shape ``(n_to, 2)`` with ``[lon, lat]`` columns;
            the query points.
        k: Number of neighbours to return per query point.

    Returns:
        Integer array of shape ``(n_to, k)``. Row ``i`` holds the indices into
        ``pos_from`` of the neighbours of ``pos_to[i]``, nearest first.

    Note:
        The haversine metric expects ``[lat, lon]`` in **radians**. This
        function only swaps the columns; it does not convert units.
        NOTE(review): get_healpix_grid returns degrees, so confirm that callers
        convert (e.g. with ``np.deg2rad``) before calling.
    """
    # Swap [lon, lat] -> [lat, lon], the column order haversine requires.
    # np.asarray lets torch tensors through; negative-step slicing needs NumPy.
    pos_from_latlon = np.asarray(pos_from)[:, ::-1]
    pos_to_latlon = np.asarray(pos_to)[:, ::-1]
    tree = BallTree(pos_from_latlon, metric='haversine')
    # query returns (distances, indices); only the indices are used.
    _, neighbors = tree.query(pos_to_latlon, k=k)
    return neighbors