File size: 1,158 Bytes
5f226eb
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
import torch
import numpy as np
import healpy as hp
from sklearn.neighbors import BallTree


def rad_to_xyz(lonlat: torch.Tensor):
    """Map longitude/latitude angles (radians) onto the unit sphere.

    The last axis of ``lonlat`` holds ``(lon, lat)`` at positions 0 and 1;
    any leading axes are preserved. The result keeps those leading axes and
    ends in an axis of size 3 holding the Cartesian ``(x, y, z)``.
    """
    lon, lat = lonlat[..., 0], lonlat[..., 1]

    # x and y share the cos(lat) factor; z depends on latitude alone.
    cos_lat = torch.cos(lat)
    components = (
        cos_lat * torch.cos(lon),
        cos_lat * torch.sin(lon),
        torch.sin(lat),
    )
    return torch.stack(components, dim=-1)


def get_healpix_grid(nside: int) -> torch.Tensor:
    """Return the centre of every HEALPix pixel as ``(lon, lat)`` in degrees.

    Pixels are enumerated in NESTED ordering (``nest=True``), so row ``i``
    corresponds to nested pixel index ``i``. Column 0 is longitude
    (healpy's ``phi``) and column 1 is latitude, obtained as 90 degrees minus
    the colatitude ``theta``. The result has shape ``(npix, 2)`` and dtype
    float32.
    """
    npix = hp.nside2npix(nside)
    colat_rad, lon_rad = hp.pix2ang(nside, np.arange(npix), nest=True)

    # healpy returns colatitude; shift it to latitude measured from the equator.
    lon_deg = torch.from_numpy(np.rad2deg(lon_rad))
    lat_deg = torch.from_numpy(90. - np.rad2deg(colat_rad))

    return torch.stack((lon_deg, lat_deg), dim=-1).float()


def get_neighbors(pos_from: np.ndarray, pos_to: np.ndarray, k: int = 8) -> np.ndarray:
    """Find the ``k`` nearest ``pos_from`` points for every ``pos_to`` point.

    A scikit-learn ``BallTree`` with the haversine (great-circle) metric is
    built over ``pos_from`` and queried with ``pos_to``.

    Args:
        pos_from: Array of shape (n_from, 2) whose columns are ``(lon, lat)``
            in radians; these are the points indexed by the tree. CPU torch
            tensors are also accepted and converted with ``np.asarray``.
        pos_to: Array of shape (n_to, 2), same column order and units as
            ``pos_from``; these are the query points.
        k: Number of neighbours returned per query point.

    Returns:
        Integer array of shape (n_to, k). Row ``i`` holds indices into
        ``pos_from`` of the neighbours of ``pos_to[i]``, sorted nearest first
        (scikit-learn's default ``sort_results=True``).

    Raises:
        ValueError: Raised by scikit-learn when ``k`` exceeds the number of
            points in ``pos_from``.
    """
    # sklearn's haversine metric expects (lat, lon) in radians, so reverse the
    # (lon, lat) columns. np.asarray first so torch tensors, which reject
    # negative-step slices, work too; it is a no-op for ndarrays.
    pos_from_rad = np.asarray(pos_from)[:, ::-1]
    pos_to_rad = np.asarray(pos_to)[:, ::-1]

    tree = BallTree(pos_from_rad, metric='haversine')
    # Only the indices are used, so skip returning distances.
    return tree.query(pos_to_rad, k=k, return_distance=False)