# fairchem_leaderboard / evaluator_utils.py
# mshuaibi's picture
# add utils
# 5258492
import numpy as np
def reorder(ref: np.ndarray, to_reorder: np.ndarray) -> np.ndarray:
    """
    Get the ordering so that `to_reorder[ordering]` == ref.

    eg:
        ref = [c, a, b]
        to_reorder = [b, a, c]
        order = reorder(ref, to_reorder)  # [2, 1, 0]
        assert ref == to_reorder[order]

    Parameters
    ----------
    ref : np.ndarray
        Reference array. Must not contain duplicates.
    to_reorder : np.ndarray
        Array to re-order. Must not contain duplicates.
        Items must be the same as in `ref`.

    Returns
    -------
    np.ndarray
        Integer index array (dtype `np.intp`) to apply on `to_reorder`.
        Empty inputs yield an empty *integer* array, so the result is
        always usable as an index.

    Raises
    ------
    AssertionError
        If either input contains duplicates, or if the two inputs do not
        hold exactly the same set of items.
    """
    # Explicit raises (instead of `assert` statements) keep the validation
    # active under `python -O`. AssertionError is kept so that callers
    # catching the original exception type continue to work.
    ref_items = set(ref)
    if len(ref) != len(ref_items):
        raise AssertionError("`ref` contains duplicate items.")

    # A dict keeps one entry per distinct item, so a size mismatch with the
    # input means `to_reorder` holds duplicates.
    item_to_idx = {item: idx for idx, item in enumerate(to_reorder)}
    if len(to_reorder) != len(item_to_idx):
        raise AssertionError("`to_reorder` contains duplicate items.")

    if item_to_idx.keys() != ref_items:
        raise AssertionError("`ref` and `to_reorder` do not contain the same items.")

    # Force an integer dtype: without it an empty list would produce a
    # float64 array, which NumPy rejects as an index.
    return np.array([item_to_idx[item] for item in ref], dtype=np.intp)
def get_order(annotations_ids, submission_ids):
    """
    Validate submission IDs against annotation IDs and return the ordering
    that aligns the submission with the annotations.

    Parameters
    ----------
    annotations_ids
        Reference IDs; the returned ordering follows this sequence.
    submission_ids
        Submitted IDs. Must hold exactly the same IDs as `annotations_ids`,
        without duplicates.

    Returns
    -------
    np.ndarray
        The result of `reorder(annotations_ids, submission_ids)`, i.e. the
        ordering to apply on `submission_ids` to match `annotations_ids`.

    Raises
    ------
    Exception
        If the two ID collections do not contain the same IDs. The message
        reports the number of missing / unexpected IDs and up to three
        examples of each.
    AssertionError
        If `submission_ids` contains duplicates.
    """
    # Use sets for faster comparison
    submission_set = set(submission_ids)
    annotations_set = set(annotations_ids)

    if submission_set != annotations_set:
        missing_ids = annotations_set - submission_set
        unexpected_ids = submission_set - annotations_set
        # Only a short sample is shown to keep the message readable.
        details = (
            f"{len(missing_ids)} missing IDs: ({list(missing_ids)[:3]}, ...)\n"
            f"{len(unexpected_ids)} unexpected IDs: ({list(unexpected_ids)[:3]}, ...)"
        )
        raise Exception(f"IDs don't match.\n{details}")

    # Raised explicitly (rather than via an `assert` statement) so the check
    # is not stripped under `python -O`; the exception type and message are
    # the same as before.
    if len(submission_ids) != len(submission_set):
        raise AssertionError("Duplicate IDs found in submission.")

    return reorder(annotations_ids, submission_ids)