# βββββββββββββββββββββββββββββββββββββββββββββ
# src/question_generator.py (v5)
# Fix: replaced removed pipeline("text2text-generation")
# with direct AutoTokenizer + T5ForConditionalGeneration
# inference, compatible with transformers >= 4.46.
# βββββββββββββββββββββββββββββββββββββββββββββ
import re
import sys, os
import warnings
import torch
from transformers import AutoTokenizer, T5ForConditionalGeneration
# Make the project root importable so that `config` (which lives one level
# above this file) resolves when running from inside src/.
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from config import QG_MODEL_NAME, MAX_QUESTIONS
warnings.filterwarnings("ignore") # suppress HuggingFace FutureWarnings
# -- Lazy model cache (avoids re-loading on every Streamlit rerun) --
# Both start as None ("not loaded") and are populated once by _get_model().
_tokenizer = None  # AutoTokenizer instance after first load
_qg_model = None   # T5ForConditionalGeneration instance after first load
def _get_model():
    """Return the cached (tokenizer, model) pair, loading both on first use.

    The pair is stashed in module-level globals so repeated Streamlit
    reruns (which re-execute the script) reuse the already-loaded model.
    """
    global _tokenizer, _qg_model
    if _tokenizer is not None and _qg_model is not None:
        return _tokenizer, _qg_model
    print(f"[INFO] Loading QG model: {QG_MODEL_NAME} ...")
    _tokenizer = AutoTokenizer.from_pretrained(QG_MODEL_NAME, use_fast=False)
    model = T5ForConditionalGeneration.from_pretrained(QG_MODEL_NAME)
    model.eval()  # inference only -- disable dropout etc.
    _qg_model = model
    print("[INFO] Model loaded.")
    return _tokenizer, _qg_model
def highlight_answer(sentence: str, answer: str) -> str:
    """Wrap the first case-insensitive occurrence of *answer* in <hl> tags.

    The <hl> markers tell the T5 question-generation model which span of
    the sentence the question should target.  If *answer* does not occur
    in *sentence*, the sentence is returned unchanged.
    """
    pattern = re.compile(re.escape(answer), re.IGNORECASE)
    # Use a callable replacement: with a plain string, backslashes or group
    # references inside *answer* (e.g. r"C:\new", r"\1") would be interpreted
    # as regex replacement escapes and corrupt the output or raise re.error.
    return pattern.sub(lambda _m: f"<hl> {answer} <hl>", sentence, count=1)
def answer_is_addressable(question: str, answer: str) -> bool:
    """
    Check that the question is actually ASKING FOR the answer.

    Rejects:
      - Circular: answer text (or its acronym) appears in the question,
        e.g. Q: "What is the name of ISRO?"
             A: "The Indian Space Research Organisation"
      - Too vague: question is only 4 words or fewer
      - No leading question word
      - Answer is a substring of the question
      - Questions asking for a name/abbreviation (usually circular)
    """
    q = question.strip()
    a = answer.strip()
    # Must end with ?
    if not q.endswith("?"):
        return False
    # Must start with a question word
    q_lower = q.lower()
    if not any(q_lower.startswith(w) for w in
               ["what", "who", "when", "where", "which", "how", "why"]):
        return False
    # Must be at least 5 words
    if len(q.split()) < 5:
        return False
    # Answer must NOT appear verbatim in the question
    if a.lower() in q_lower:
        return False
    # Abbreviation trap: reject if the question contains an acronym of the
    # answer (e.g. "ISRO" when the answer is "Indian Space Research
    # Organisation").  Two fixes vs. the old implementation:
    #   1. Leading articles are stripped first, so "The Indian Space
    #      Research Organisation" yields "isro" (previously "tisro", which
    #      silently missed the match).
    #   2. The acronym must match a whole word; a bare substring test
    #      falsely rejected valid questions whenever the acronym happened
    #      to occur inside an unrelated word (e.g. "is" inside "mission").
    answer_words = [w.lower() for w in a.split() if len(w) > 1]
    while answer_words and answer_words[0] in ("the", "an"):
        answer_words.pop(0)
    abbrev = "".join(w[0] for w in answer_words if w.isalpha())
    if len(abbrev) >= 2 and re.search(rf"\b{re.escape(abbrev)}\b", q_lower):
        return False
    # Reject questions asking about name/abbreviation -- usually circular
    circular_patterns = [
        r"what (is|was|were) the (full |official )?name",
        r"what (does|did) .{1,10} stand for",
        r"what (is|was) the abbreviation",
        r"what (is|was) .{1,15} also (known|called)",
    ]
    return not any(re.search(pat, q_lower) for pat in circular_patterns)
def generate_question(sentence: str, answer: str) -> str | None:
    """
    Generate a question for a (sentence, answer) pair.

    Highlights *answer* inside *sentence* with <hl> tags, beam-searches the
    T5 model for candidate questions, and returns the first candidate that
    passes answer_is_addressable(), or None if generation fails or every
    candidate is rejected.

    Uses direct tokenizer + model inference (no pipeline()).
    """
    tokenizer, qg_model = _get_model()
    highlighted = highlight_answer(sentence, answer)
    input_text = f"generate question: {highlighted}"
    try:
        inputs = tokenizer(
            input_text,
            return_tensors="pt",
            max_length=512,
            truncation=True,
        )
        with torch.no_grad():
            # Pass the attention_mask alongside input_ids: the old code
            # dropped it, which makes generate() infer a mask (a silent
            # correctness hazard for padded input, and a transformers
            # warning even for single sequences).
            output_ids = qg_model.generate(
                inputs["input_ids"],
                attention_mask=inputs.get("attention_mask"),
                max_new_tokens=64,
                num_beams=5,
                num_return_sequences=3,
                early_stopping=True,
            )
        outputs = [
            tokenizer.decode(ids, skip_special_tokens=True)
            for ids in output_ids
        ]
    except Exception as e:
        # Best-effort: log and signal "no question" rather than crash the app.
        print(f" [QG] Generation error: {e}")
        return None
    for q in outputs:
        q = q.strip()
        if not q.endswith("?"):
            q += "?"
        if answer_is_addressable(q, answer):
            return q
    return None
def generate_questions(sentence_answers: dict) -> list:
    """
    Build question records from a mapping of sentence -> answer candidates.

    For each sentence, candidates are tried in priority order and the first
    one that yields a valid question wins.  Stops once MAX_QUESTIONS
    results have been collected.
    """
    results = []
    for sentence, candidates in sentence_answers.items():
        if len(results) >= MAX_QUESTIONS:
            break
        question = None
        accepted_answer = None
        for answer in candidates:
            if len(answer.strip()) < 2:
                continue  # degenerate candidate (empty / single character)
            question = generate_question(sentence, answer)
            if question is None:
                print(f" [QG] β Rejected for answer '{answer}'")
                continue
            accepted_answer = answer
            break
        if question is None:
            print(f" [QG] β No valid question for: '{sentence[:60]}'")
            continue
        print(f" [QG] β Q: {question}")
        print(f" A: {accepted_answer}")
        results.append({
            "question" : question,
            "answer" : accepted_answer,
            "sentence" : sentence,
        })
    return results
if __name__ == "__main__":
    # Smoke test: push hand-picked (sentence, answer) pairs through the
    # generator and show which survive validation.
    tests = [
        # Good cases: specific named answers
        ("ISRO was founded in 1969 by Vikram Sarabhai.", "Vikram Sarabhai"),
        ("Aryabhata was India's first satellite, launched in 1975.", "Aryabhata"),
        ("The Chandrayaan-1 mission in 2008 discovered water on the Moon.", "2008"),
        ("Chandrayaan-3 landed near the lunar south pole in 2023.", "Chandrayaan-3"),
        ("The Taj Mahal was built by Shah Jahan in 1632 in Agra.", "Shah Jahan"),
        # Bad cases: should all be rejected
        ("The Indian Space Research Organisation (ISRO) was founded in 1969.", "The Indian Space Research Organisation"),
        ("ISRO developed India's first satellite.", "India"),
    ]
    print("\n=== QUESTION GENERATION TEST ===\n")
    for sent, ans in tests:
        result = generate_question(sent, ans)
        if result:
            status = "β"
        else:
            status = "β (rejected)"
        print(f" [{status}]")
        print(f" Sentence: {sent}")
        print(f" Answer : {ans}")
        print(f" Question: {result}\n")