mshuaibi commited on
Commit
5258492
·
1 Parent(s): 5985166

add utils

Browse files
Files changed (2) hide show
  1. evaluator.py +1 -1
  2. evaluator_utils.py +53 -0
evaluator.py CHANGED
@@ -304,4 +304,4 @@ def evaluate(
304
  else:
305
  raise ValueError(f"Unknown eval_type: {eval_type}")
306
 
307
- return metrics
 
304
  else:
305
  raise ValueError(f"Unknown eval_type: {eval_type}")
306
 
307
+ return metrics
evaluator_utils.py ADDED
@@ -0,0 +1,53 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+
3
+
4
+ def reorder(ref: np.ndarray, to_reorder: np.ndarray) -> np.ndarray:
5
+ """
6
+ Get the ordering so that `to_reorder[ordering]` == ref.
7
+
8
+ eg:
9
+ ref = [c, a, b]
10
+ to_reorder = [b, a, c]
11
+ order = reorder(ref, to_reorder) # [2, 1, 0]
12
+ assert ref == to_reorder[order]
13
+
14
+ Parameters
15
+ ----------
16
+ ref : np.ndarray
17
+ Reference array. Must not contains duplicates.
18
+ to_reorder : np.ndarray
19
+ Array to re-order. Must not contains duplicates.
20
+ Items must be the same as in `ref`.
21
+
22
+ Returns
23
+ -------
24
+ np.ndarray
25
+ the ordering to apply on `to_reorder`
26
+ """
27
+ assert len(ref) == len(set(ref))
28
+ assert len(to_reorder) == len(set(to_reorder))
29
+ assert set(ref) == set(to_reorder)
30
+ item_to_idx = {item: idx for idx, item in enumerate(to_reorder)}
31
+ return np.array([item_to_idx[item] for item in ref])
32
+
33
+
34
+ def get_order(annotations_ids, submission_ids):
35
+ # Use sets for faster comparison
36
+ submission_set = set(submission_ids)
37
+ annotations_set = set(annotations_ids)
38
+
39
+ if submission_set != annotations_set:
40
+ missing_ids = annotations_set - submission_set
41
+ unexpected_ids = submission_set - annotations_set
42
+
43
+ details = (
44
+ f"{len(missing_ids)} missing IDs: ({list(missing_ids)[:3]}, ...)\n"
45
+ f"{len(unexpected_ids)} unexpected IDs: ({list(unexpected_ids)[:3]}, ...)"
46
+ )
47
+ raise Exception(f"IDs don't match.\n{details}")
48
+
49
+ assert len(submission_ids) == len(
50
+ submission_set
51
+ ), "Duplicate IDs found in submission."
52
+
53
+ return reorder(annotations_ids, submission_ids)