File size: 5,392 Bytes
73633b5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
# ─────────────────────────────────────────────
#  src/distractor_generator.py  (v3)
#  Distractors MUST be the same entity type
#  as the correct answer whenever the passage
#  provides enough of them (see the fallbacks
#  in get_distractors).
#  e.g. answer=PERSON → distractors are PERSONs
#       answer=DATE   → distractors are DATEs
# ─────────────────────────────────────────────

import random
import sys, os

import nltk
from nltk.corpus import wordnet

sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from config import NUM_DISTRACTORS

# Make sure the NLTK data used by get_wordnet_distractors is available.
# This runs at import time; quiet=True suppresses nltk's progress output.
# NOTE(review): nltk.download may attempt a network fetch when the data is not
# already installed locally — confirm this is acceptable for every importer.
nltk.download('wordnet', quiet=True)
nltk.download('omw-1.4', quiet=True)


def get_same_label_distractors(answer: str, answer_label: str,
                               all_entities: list, n: int) -> list:
    """
    Find entities from the passage that have the SAME NER label as the answer.
    This ensures distractors are always the same type as the answer.

    Args:
        answer: the correct answer text; never returned as a distractor.
        answer_label: NER label of the answer. An empty label means the
            answer's type is unknown, so no entity can be "the same type".
        all_entities: list of {"text": str, "label": str} dicts.
        n: maximum number of distractors to return.

    Returns:
        Up to ``n`` entity texts in passage order, deduplicated
        case-insensitively and excluding the answer. Returns [] when
        ``answer_label`` is empty or ``n <= 0``.
    """
    # An unknown label must not match untyped entities (label == "").
    if not answer_label or n <= 0:
        return []

    distractors = []
    seen = {answer.lower()}

    for ent in all_entities:
        if ent["label"] != answer_label:
            continue
        key = ent["text"].lower()
        if key in seen:
            continue
        seen.add(key)
        distractors.append(ent["text"])
        if len(distractors) >= n:
            break  # enough collected; no need to scan the rest

    return distractors


def get_wordnet_distractors(answer: str, n: int) -> list:
    """WordNet co-hyponyms ("siblings") of the answer — same semantic category.

    For each of the answer's first five synsets, walk up to its hypernyms
    (including instance hypernyms, which is how WordNet links proper nouns
    such as countries to their category) and collect the lemma names of
    those hypernyms' hyponyms and instance hyponyms. When the full answer
    (spaces joined with "_") has no synsets, the synsets of its individual
    whitespace-separated words are used instead.

    Lemmas of the answer's own synsets are skipped. They are synonyms of the
    answer (e.g. "domestic dog" for "dog") and would be correct answers, not
    distractors.

    Args:
        answer: the correct answer text.
        n: maximum number of distractors to return.

    Returns:
        Up to ``n`` candidate strings in random order. Candidates are
        title-cased when the answer starts with an uppercase character.
    """
    answer_lower = answer.lower()
    answer_key = answer_lower.replace(" ", "_")

    synsets = wordnet.synsets(answer_key)
    if not synsets:
        synsets = [s for word in answer.split()
                   for s in wordnet.synsets(word.lower())]

    # The answer's own senses and every lemma in them are synonyms.
    own_synsets = set(synsets)
    synonyms = {
        lemma.replace("_", " ").lower()
        for s in synsets
        for lemma in s.lemma_names()
    }
    synonyms.add(answer_lower)

    # answer[:1] avoids an IndexError on an empty answer.
    title_case = answer[:1].isupper()
    distractors = set()

    for synset in synsets[:5]:
        parents = synset.hypernyms() + synset.instance_hypernyms()
        for parent in parents:
            for sibling in parent.hyponyms() + parent.instance_hyponyms():
                if sibling in own_synsets:
                    continue
                for lemma in sibling.lemma_names():
                    word = lemma.replace("_", " ")
                    if len(word) <= 1 or word.lower() in synonyms:
                        continue
                    distractors.add(word.title() if title_case else word)
        if len(distractors) >= n * 3:
            break

    # Sort before shuffling. Set iteration order depends on per-process
    # string hashing, so this keeps results reproducible under random.seed().
    result = sorted(distractors)
    random.shuffle(result)
    return result[:n]


def _find_answer_label(answer: str, all_entities: list) -> str:
    """Return the NER label of the entity matching ``answer``, or "" if none."""
    target = answer.lower()
    # Exact (case-insensitive) match first.
    for ent in all_entities:
        if ent["text"].lower() == target:
            return ent["label"]
    # Fuzzy fallback: the first entity whose text contains the answer. This is
    # skipped for a blank answer, since "" is a substring of every string and
    # would otherwise pick an arbitrary entity's label.
    if target.strip():
        for ent in all_entities:
            if target in ent["text"].lower():
                return ent["label"]
    return ""


def get_distractors(answer: str, all_entities: list,
                    passage_doc=None, n: int = NUM_DISTRACTORS) -> list:
    """
    Main distractor function.

    Strategy (steps 2-4 run only while fewer than n distractors are collected):
      1. Same-label entities from the passage  (best — contextual + same type)
      2. WordNet siblings                       (good for common nouns)
      3. Cross-label entities from passage      (last resort, still real words)
      4. Generic placeholders ("None of the above", ...)
    Never mixes types if same-label gives enough results.

    Args:
        answer: the correct answer text.
        all_entities: list of {"text": str, "label": str} dicts.
        passage_doc: not used by this function; kept for interface
            compatibility.
        n: number of distractors wanted.

    Returns:
        At most ``n`` distractor strings, deduplicated case-insensitively and
        never equal (case-insensitively) to the answer.
    """
    collected = []
    seen = {answer.lower()}  # contains the answer, so it is never added

    def add(candidates):
        # Accepts entity dicts or plain strings; dedups case-insensitively.
        for c in candidates:
            text = c["text"] if isinstance(c, dict) else c
            key = text.lower()
            if key not in seen:
                seen.add(key)
                collected.append(text)

    answer_label = _find_answer_label(answer, all_entities)

    # Strategy 1: same label from passage
    add(get_same_label_distractors(answer, answer_label, all_entities, n * 2))

    # Strategy 2: WordNet
    if len(collected) < n:
        add(get_wordnet_distractors(answer, n * 2))

    # Strategy 3: any other passage entity (cross-label fallback)
    if len(collected) < n:
        add(all_entities)

    # Strategy 4: only if still short, add generic placeholders. They go
    # through add() so they share the same case-insensitive dedup.
    placeholders = ["None of the above", "Cannot be determined", "All of the above"]
    for p in placeholders:
        if len(collected) >= n:
            break
        add([p])

    return collected[:n]


if __name__ == "__main__":
    # Demo: simulate the entity list the preprocessor would produce and print
    # the distractors chosen for one answer of each entity type.
    entities = [
        {"text": "ISRO",            "label": "ORG"},
        {"text": "NASA",            "label": "ORG"},
        {"text": "ESA",             "label": "ORG"},
        {"text": "Vikram Sarabhai", "label": "PERSON"},
        {"text": "Vince McMahon",   "label": "PERSON"},
        {"text": "John Cena",       "label": "PERSON"},
        {"text": "1969",            "label": "DATE"},
        {"text": "1975",            "label": "DATE"},
        {"text": "2008",            "label": "DATE"},
        {"text": "India",           "label": "GPE"},
        {"text": "United States",   "label": "GPE"},
        {"text": "China",           "label": "GPE"},
    ]

    # (answer, expected label); the label is only displayed, not checked.
    tests = [
        ("Vikram Sarabhai", "PERSON"),
        ("1969",            "DATE"),
        ("India",           "GPE"),
        ("ISRO",            "ORG"),
    ]

    print("=== DISTRACTOR TEST ===\n")
    for answer, label in tests:
        d = get_distractors(answer, entities)
        # The arrow was mis-encoded as "β†’" in the original output string.
        print(f"  Answer ({label:8s}): {answer:20s} → {d}")