Spaces:
Running
Running
File size: 5,392 Bytes
73633b5 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 | # βββββββββββββββββββββββββββββββββββββββββββββ
# src/distractor_generator.py (v3)
# Distractors MUST be the same entity type
# as the correct answer (cross-type fallbacks are
# used only when too few same-type candidates exist).
# e.g. answer=PERSON → distractors are PERSONs
#      answer=DATE   → distractors are DATEs
# ─────────────────────────────────────────────
import random
import sys, os
import nltk
from nltk.corpus import wordnet
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from config import NUM_DISTRACTORS
# Fetch the WordNet corpora used by get_wordnet_distractors, once, at import
# time (WordNet first, then the multilingual omw-1.4 data).
for _corpus in ('wordnet', 'omw-1.4'):
    nltk.download(_corpus, quiet=True)
del _corpus  # keep the loop variable out of the module namespace
def get_same_label_distractors(answer: str, answer_label: str,
                               all_entities: list, n: int) -> list:
    """
    Return up to ``n`` passage entities that share the answer's NER label.

    Restricting candidates to the answer's label keeps distractors the same
    type as the answer (PERSON with PERSON, DATE with DATE, ...).

    Args:
        answer: The correct answer text; it is never returned (compared
            case-insensitively).
        answer_label: NER label of the answer. An empty label means "unknown";
            in that case nothing is returned rather than matching unlabeled
            entities.
        all_entities: List of {"text": str, "label": str} dicts, in passage
            order. ``None`` is treated as an empty list.
        n: Maximum number of distractors to return; ``n <= 0`` yields [].

    Returns:
        Entity texts in first-seen order, de-duplicated case-insensitively.
    """
    if not answer_label or n <= 0:
        return []
    distractors = []
    seen = {answer.lower()}
    for ent in all_entities or ():
        if ent["label"] != answer_label:
            continue
        text = ent["text"]
        key = text.lower()
        # Skip blank texts, the answer itself, and case-insensitive repeats.
        if not text.strip() or key in seen:
            continue
        seen.add(key)
        distractors.append(text)
        if len(distractors) >= n:
            break  # enough collected; no need to scan the rest
    return distractors
def get_wordnet_distractors(answer: str, n: int) -> list:
    """
    Return up to ``n`` WordNet siblings of the answer — same semantic category.

    Siblings are the lemma names of the hyponyms of the answer's hypernyms
    (looking at the answer's first five synsets). The whole answer is looked
    up first (spaces → underscores); if WordNet has no entry for it, the
    synsets of each individual word are used instead.

    Args:
        answer: The correct answer. Its own lemma is excluded. If it starts
            with an uppercase letter, candidates are title-cased to match.
        n: Maximum number of distractors; ``n <= 0`` or a blank answer
            yields [].

    Returns:
        Up to ``n`` candidates in random order (``random.shuffle``),
        de-duplicated case-insensitively.
    """
    if n <= 0 or not answer.strip():
        return []
    answer_lower = answer.lower()
    synsets = wordnet.synsets(answer_lower.replace(" ", "_"))
    if not synsets:
        # No entry for the full phrase: fall back to its individual words.
        synsets = [s for word in answer.split()
                   for s in wordnet.synsets(word.lower())]
    capitalize = answer[0].isupper()
    limit = n * 3  # over-collect so the shuffle below has some variety

    def _sibling_words():
        """Yield sibling lemma names with underscores replaced by spaces."""
        for synset in synsets[:5]:
            for hypernym in synset.hypernyms():
                for hyponym in hypernym.hyponyms():
                    for lemma in hyponym.lemma_names():
                        yield lemma.replace("_", " ")

    # lower-cased key -> display form; dict keeps first-seen order.
    candidates = {}
    for word in _sibling_words():
        key = word.lower()
        if key == answer_lower or len(word) <= 1 or key in candidates:
            continue
        candidates[key] = word.title() if capitalize else word
        if len(candidates) >= limit:
            # Stops the whole traversal, not just the innermost loop.
            break
    result = list(candidates.values())
    random.shuffle(result)
    return result[:n]
def get_distractors(answer: str, all_entities: list,
                    passage_doc=None, n: int = NUM_DISTRACTORS) -> list:
    """
    Build up to ``n`` multiple-choice distractors for ``answer``.

    Strategy, in priority order (each later step runs only while short):
      1. Same-label entities from the passage (best — contextual + same type).
      2. WordNet siblings of the answer (good for common nouns).
      3. Any other passage entity (cross-label; last resort, still real words).
      4. Generic placeholders ("None of the above", ...).
    Types are never mixed when step 1 alone yields enough results.

    Args:
        answer: The correct answer; never returned as a distractor (compared
            case-insensitively).
        all_entities: List of {"text": str, "label": str} dicts from the
            passage. ``None`` is treated as an empty list.
        passage_doc: Accepted for interface compatibility; not used here.
        n: Number of distractors wanted; defaults to config.NUM_DISTRACTORS.

    Returns:
        At most ``n`` distinct (case-insensitive) distractor strings. Fewer
        are returned only if even the placeholders run out.
    """
    entities = all_entities or []
    if n <= 0:
        return []
    answer_key = answer.lower()
    collected = []
    seen = {answer_key}  # lower-cased texts already used, incl. the answer

    def add(candidates):
        """Append new candidates (str or entity dict), skipping blanks/repeats."""
        for c in candidates:
            text = c["text"] if isinstance(c, dict) else c
            key = text.lower()
            if not text.strip() or key in seen:
                continue
            seen.add(key)
            collected.append(text)

    # Find the answer's NER label: first an exact (case-insensitive) match,
    # then the first entity whose text contains the answer. The containment
    # fallback is skipped for a blank answer, which would match everything.
    answer_label = next(
        (ent["label"] for ent in entities if ent["text"].lower() == answer_key),
        "",
    )
    if not answer_label and answer.strip():
        answer_label = next(
            (ent["label"] for ent in entities
             if answer_key in ent["text"].lower()),
            "",
        )

    # Strategy 1: same-label entities from the passage.
    add(get_same_label_distractors(answer, answer_label, entities, n * 2))
    # Strategy 2: WordNet siblings.
    if len(collected) < n:
        add(get_wordnet_distractors(answer, n * 2))
    # Strategy 3: any other passage entity (cross-label fallback).
    if len(collected) < n:
        add(entities)
    # Strategy 4: generic placeholders, only while still short. Routing them
    # through add() keeps de-duplication case-insensitive and guarantees a
    # placeholder never duplicates the answer.
    for placeholder in ("None of the above", "Cannot be determined",
                        "All of the above"):
        if len(collected) >= n:
            break
        add([placeholder])
    return collected[:n]
if __name__ == "__main__":
    # Simulated preprocessor output, grouped by NER label and flattened below
    # into the {"text", "label"} dicts get_distractors expects (same order).
    texts_by_label = {
        "ORG": ("ISRO", "NASA", "ESA"),
        "PERSON": ("Vikram Sarabhai", "Vince McMahon", "John Cena"),
        "DATE": ("1969", "1975", "2008"),
        "GPE": ("India", "United States", "China"),
    }
    sample_entities = [
        {"text": text, "label": label}
        for label, texts in texts_by_label.items()
        for text in texts
    ]
    print("=== DISTRACTOR TEST ===\n")
    # One (answer, label) probe per entity type.
    for answer_text, answer_type in (
        ("Vikram Sarabhai", "PERSON"),
        ("1969", "DATE"),
        ("India", "GPE"),
        ("ISRO", "ORG"),
    ):
        options = get_distractors(answer_text, sample_entities)
        print(f" Answer ({answer_type:8s}): {answer_text:20s} β {options}")