# mcq_generator/src/question_generator.py
# (Hugging Face upload metadata preserved below, commented out so the file
#  parses as Python: uploader "tanmmayyy", commit message "final", 6576362)
# ─────────────────────────────────────────────
# src/question_generator.py (v5)
# Fix: replaced removed pipeline("text2text-generation")
# with direct AutoTokenizer + T5ForConditionalGeneration
# inference, compatible with transformers >= 4.46.
# ─────────────────────────────────────────────
import re
import sys, os
import warnings
import torch
from transformers import AutoTokenizer, T5ForConditionalGeneration
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from config import QG_MODEL_NAME, MAX_QUESTIONS
warnings.filterwarnings("ignore") # suppress HuggingFace FutureWarnings
# ── Lazy model cache (avoids re-loading on every Streamlit rerun) ──
_tokenizer = None
_qg_model = None
def _get_model():
    """Return the (tokenizer, model) pair, loading and caching on first call.

    The pair is stashed in module-level globals so a Streamlit rerun (which
    re-executes the script but keeps the module alive) does not trigger a
    fresh checkpoint load.
    """
    global _tokenizer, _qg_model
    already_loaded = _tokenizer is not None and _qg_model is not None
    if not already_loaded:
        print(f"[INFO] Loading QG model: {QG_MODEL_NAME} ...")
        # use_fast=False: stick with the slow (SentencePiece) tokenizer.
        _tokenizer = AutoTokenizer.from_pretrained(QG_MODEL_NAME, use_fast=False)
        _qg_model = T5ForConditionalGeneration.from_pretrained(QG_MODEL_NAME)
        _qg_model.eval()  # inference only — disable dropout etc.
        print("[INFO] Model loaded.")
    return _tokenizer, _qg_model
def highlight_answer(sentence: str, answer: str) -> str:
    """Wrap the first occurrence of *answer* in ``<hl>`` tags for the T5 model.

    Matching is case-insensitive. A callable replacement is used so that
    backslashes in *answer* are inserted literally; the original passed the
    answer through the replacement-string engine of ``re.sub``, which raises
    ``re.error: bad escape`` for answers containing sequences like ``\\b``.

    If *answer* does not occur in *sentence*, the sentence is returned
    unchanged (no tags are added).
    """
    pattern = re.compile(re.escape(answer), re.IGNORECASE)
    # count=1: highlight only the first occurrence.
    return pattern.sub(lambda _m: f"<hl> {answer} <hl>", sentence, count=1)
def answer_is_addressable(question: str, answer: str) -> bool:
    """
    Check that the question is actually ASKING FOR the answer.
    Rejects:
      - Circular: answer text appears in the question
        e.g. Q: "What is the name of ISRO?" A: "The Indian Space Research Organisation"
        (ISRO is an abbreviation of the answer — circular)
      - Too vague: question is only 4 words or fewer
      - No question word
      - Answer is a substring of the question
      - An initialism of the answer appears in the question
    """
    q = question.strip()
    a = answer.strip()
    # Must end with ?
    if not q.endswith("?"):
        return False
    q_lower = q.lower()
    # Must open with an interrogative word.
    if not q_lower.startswith(("what", "who", "when", "where", "which", "how", "why")):
        return False
    # Must be at least 5 words — shorter questions are too vague.
    if len(q.split()) < 5:
        return False
    # Answer must NOT appear verbatim in the question.
    if a.lower() in q_lower:
        return False
    # Abbreviation trap: derive the answer's initialism and reject if it
    # appears in the question (e.g. "ISRO" in question, answer is
    # "The Indian Space Research Organisation").  Leading articles and
    # function words are skipped: the original built the initialism from
    # ALL words, producing "tisro" and missing its own documented example.
    stopwords = {"the", "a", "an", "of", "and", "for"}
    answer_words = [w.lower() for w in a.split() if len(w) > 1]
    abbrev = "".join(
        w[0] for w in answer_words if w.isalpha() and w not in stopwords
    )
    if len(abbrev) >= 2 and abbrev in q_lower:
        return False
    # Reject questions asking about name/abbreviation — usually circular.
    circular_patterns = [
        r"what (is|was|were) the (full |official )?name",
        r"what (does|did) .{1,10} stand for",
        r"what (is|was) the abbreviation",
        r"what (is|was) .{1,15} also (known|called)",
    ]
    return not any(re.search(pat, q_lower) for pat in circular_patterns)
def generate_question(sentence: str, answer: str) -> str | None:
    """
    Generate a question for a (sentence, answer) pair.
    Returns the best valid question string, or None when generation fails
    or every beam is rejected by answer_is_addressable().
    Uses direct tokenizer + model inference (no pipeline()).
    """
    tokenizer, qg_model = _get_model()
    highlighted = highlight_answer(sentence, answer)
    input_text = f"generate question: {highlighted}"
    try:
        inputs = tokenizer(
            input_text,
            return_tensors="pt",
            max_length=512,
            truncation=True,
        )
        with torch.no_grad():
            # Pass the full encoding (**inputs) so attention_mask reaches
            # generate(); the original forwarded only input_ids, which
            # triggers a transformers warning and can misbehave with
            # padded batches.  For a single unpadded sequence the output
            # is unchanged.
            output_ids = qg_model.generate(
                **inputs,
                max_new_tokens=64,
                num_beams=5,
                num_return_sequences=3,
                early_stopping=True,
            )
        outputs = [
            tokenizer.decode(ids, skip_special_tokens=True)
            for ids in output_ids
        ]
    except Exception as e:
        # Best-effort: log and signal failure rather than crash the app.
        print(f" [QG] Generation error: {e}")
        return None
    for q in outputs:
        q = q.strip()
        # Some beams drop the trailing question mark; restore it before
        # validation (answer_is_addressable requires it).
        if not q.endswith("?"):
            q += "?"
        if answer_is_addressable(q, answer):
            return q
    return None
def generate_questions(sentence_answers: dict) -> list:
    """
    For each (sentence → answer candidates), generate one good question.
    Tries each answer candidate in priority order until one works; stops
    once MAX_QUESTIONS results have been collected.

    Returns a list of {"question", "answer", "sentence"} dicts.
    """
    results = []
    for sentence, candidates in sentence_answers.items():
        if len(results) >= MAX_QUESTIONS:
            break
        generated = False
        for answer in candidates:
            # Skip degenerate one-character candidates.
            if len(answer.strip()) < 2:
                continue
            question = generate_question(sentence, answer)
            if question:
                # Log markers were mojibake in the original ("βœ“" etc.,
                # UTF-8 read as Latin-1); restored to intended characters.
                print(f" [QG] ✓ Q: {question}")
                print(f" A: {answer}")
                results.append({
                    "question": question,
                    "answer": answer,
                    "sentence": sentence,
                })
                generated = True
                break
            else:
                print(f" [QG] ✗ Rejected for answer '{answer}'")
        if not generated:
            print(f" [QG] — No valid question for: '{sentence[:60]}'")
    return results
if __name__ == "__main__":
tests = [
# Good cases β€” specific named answers
("ISRO was founded in 1969 by Vikram Sarabhai.", "Vikram Sarabhai"),
("Aryabhata was India's first satellite, launched in 1975.", "Aryabhata"),
("The Chandrayaan-1 mission in 2008 discovered water on the Moon.", "2008"),
("Chandrayaan-3 landed near the lunar south pole in 2023.", "Chandrayaan-3"),
("The Taj Mahal was built by Shah Jahan in 1632 in Agra.", "Shah Jahan"),
# Bad cases β€” should all be rejected
("The Indian Space Research Organisation (ISRO) was founded in 1969.", "The Indian Space Research Organisation"),
("ISRO developed India's first satellite.", "India"),
]
print("\n=== QUESTION GENERATION TEST ===\n")
for sentence, answer in tests:
q = generate_question(sentence, answer)
status = "βœ“" if q else "βœ— (rejected)"
print(f" [{status}]")
print(f" Sentence: {sentence}")
print(f" Answer : {answer}")
print(f" Question: {q}\n")