mcq_generator / src /distractor_generator.py
tanmmayyy's picture
Initial commit — MCQ Generator with T5 + NER + WordNet
73633b5
# ─────────────────────────────────────────────
# src/distractor_generator.py (v3)
# Distractors MUST be the same entity type
# as the correct answer.
# e.g. answer=PERSON → distractors are PERSONs
#      answer=DATE   → distractors are DATEs
# ─────────────────────────────────────────────
import random
import sys, os
import nltk
from nltk.corpus import wordnet
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from config import NUM_DISTRACTORS
nltk.download('wordnet', quiet=True)
nltk.download('omw-1.4', quiet=True)
def get_same_label_distractors(answer: str, answer_label: str,
                               all_entities: list, n: int) -> list:
    """
    Collect passage entities that share the answer's NER label.

    Args:
        answer: the correct answer text; never returned as a distractor
            (compared case-insensitively).
        answer_label: NER label the distractors must carry.
        all_entities: list of {"text": str, "label": str} dicts.
        n: maximum number of distractors to return.

    Returns:
        Up to ``n`` entity texts, in passage order, with case-insensitive
        duplicates removed. Empty when ``n`` is zero or negative.
    """
    # A negative n would otherwise turn the final slice into "all but the
    # last |n| items", so treat any non-positive request as "nothing".
    if n <= 0:
        return []

    distractors = []
    seen = {answer.lower()}
    for ent in all_entities:
        if ent["label"] != answer_label:
            continue
        key = ent["text"].lower()
        if key in seen:
            continue
        seen.add(key)
        distractors.append(ent["text"])
        # Stop scanning as soon as the quota is met.
        if len(distractors) == n:
            break
    return distractors
def get_wordnet_distractors(answer: str, n: int) -> list:
    """
    Suggest WordNet "siblings" of the answer as distractors.

    Siblings are lemmas of the hyponyms of the answer's hypernyms. The
    whole answer is looked up first (spaces replaced by underscores); if
    that finds nothing, each whitespace-separated word is looked up
    instead. Only the first five synsets are explored.

    Returns a randomly ordered list of at most ``n`` candidates. They are
    title-cased when the answer starts with an uppercase letter.
    """
    target = answer.lower()
    candidates = set()

    senses = wordnet.synsets(target.replace(" ", "_"))
    if not senses:
        # Fall back to the individual words of a multi-word answer.
        for token in answer.split():
            senses += wordnet.synsets(token.lower())

    for sense in senses[:5]:
        for parent in sense.hypernyms():
            for sibling in parent.hyponyms():
                for lemma in sibling.lemma_names():
                    word = lemma.replace("_", " ")
                    # Skip the answer itself and single-character lemmas.
                    if word.lower() == target or len(word) <= 1:
                        continue
                    if answer[0].isupper():
                        candidates.add(word.title())
                    else:
                        candidates.add(word)
        # Checked once per synset: stop once the pool is three times the
        # requested size, leaving room for the shuffle to vary results.
        if len(candidates) >= n * 3:
            break

    shuffled = list(candidates)
    random.shuffle(shuffled)
    return shuffled[:n]
def get_distractors(answer: str, all_entities: list,
                    passage_doc=None, n: int = NUM_DISTRACTORS) -> list:
    """
    Build up to ``n`` distractors for ``answer``.

    Sources are tried in order; each later source is used only while
    fewer than ``n`` distractors have been collected:
      1. Passage entities with the same NER label as the answer
         (contextual and same type).
      2. WordNet siblings (useful for common nouns).
      3. Any other passage entity, regardless of label.
      4. Generic placeholder options.

    Args:
        answer: the correct answer text.
        all_entities: list of {"text": str, "label": str} dicts.
        passage_doc: accepted for interface compatibility; not used here.
        n: number of distractors wanted.

    Returns:
        At most ``n`` strings, none equal to the answer (ignoring case)
        and with no case-insensitive duplicates. The list can be shorter
        than ``n`` when every source, including the three placeholders,
        is exhausted. Empty when ``n`` is zero or negative.
    """
    # A negative n would make the final slice drop items from the end
    # instead of limiting the count.
    if n <= 0:
        return []

    answer_lc = answer.lower()
    collected = []
    seen = {answer_lc}  # seeded with the answer so it is never offered

    def add(candidates):
        """Append unseen candidates (plain strings or entity dicts)."""
        for c in candidates:
            text = c["text"] if isinstance(c, dict) else c
            key = text.lower()
            if key not in seen:
                seen.add(key)
                collected.append(text)

    # Find the answer's NER label: exact (case-insensitive) match first.
    answer_label = next(
        (ent["label"] for ent in all_entities
         if ent["text"].lower() == answer_lc),
        "",
    )
    # Fuzzy fallback: the first entity whose text contains the answer.
    # An empty answer is skipped because "" is a substring of every text
    # and would pick up an arbitrary label.
    # NOTE(review): the entity matched here stays eligible as a distractor
    # even though its text contains the answer - confirm that is intended.
    if not answer_label and answer_lc:
        answer_label = next(
            (ent["label"] for ent in all_entities
             if answer_lc in ent["text"].lower()),
            "",
        )

    # Strategy 1: same label from passage
    add(get_same_label_distractors(answer, answer_label, all_entities, n * 2))

    # Strategy 2: WordNet
    if len(collected) < n:
        add(get_wordnet_distractors(answer, n * 2))

    # Strategy 3: any other passage entity (cross-label fallback)
    if len(collected) < n:
        add(all_entities)

    # Only if still short, add generic placeholders. They go through the
    # same `seen` filter so a placeholder equal to the answer itself is
    # never returned as a distractor.
    placeholders = ["None of the above", "Cannot be determined", "All of the above"]
    for p in placeholders:
        if len(collected) >= n:
            break
        if p.lower() not in seen:
            seen.add(p.lower())
            collected.append(p)

    return collected[:n]
if __name__ == "__main__":
    # Manual smoke test: simulate the entity list the preprocessor would
    # produce, then print the distractors chosen for one answer of each
    # entity type.
    entities = [
        {"text": "ISRO", "label": "ORG"},
        {"text": "NASA", "label": "ORG"},
        {"text": "ESA", "label": "ORG"},
        {"text": "Vikram Sarabhai", "label": "PERSON"},
        {"text": "Vince McMahon", "label": "PERSON"},
        {"text": "John Cena", "label": "PERSON"},
        {"text": "1969", "label": "DATE"},
        {"text": "1975", "label": "DATE"},
        {"text": "2008", "label": "DATE"},
        {"text": "India", "label": "GPE"},
        {"text": "United States", "label": "GPE"},
        {"text": "China", "label": "GPE"},
    ]
    # (answer, label) pairs; the label is only used for display here,
    # get_distractors looks the label up from `entities` itself.
    tests = [
        ("Vikram Sarabhai", "PERSON"),
        ("1969", "DATE"),
        ("India", "GPE"),
        ("ISRO", "ORG"),
    ]
    print("=== DISTRACTOR TEST ===\n")
    for answer, label in tests:
        d = get_distractors(answer, entities)
        # The arrow was mis-decoded in the original string; restored here.
        print(f" Answer ({label:8s}): {answer:20s} → {d}")