# ─────────────────────────────────────────────
# src/question_generator.py (v5)
# Fix: replaced removed pipeline("text2text-generation")
# with direct AutoTokenizer + T5ForConditionalGeneration
# inference, compatible with transformers >= 4.46.
# ─────────────────────────────────────────────
import re
import sys, os
import warnings
import torch
from transformers import AutoTokenizer, T5ForConditionalGeneration
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from config import QG_MODEL_NAME, MAX_QUESTIONS
warnings.filterwarnings("ignore")  # suppress HuggingFace FutureWarnings
# ── Lazy model cache (avoids re-loading on every Streamlit rerun) ──
_tokenizer = None
_qg_model = None


def _get_model():
    """Return the cached (tokenizer, model) pair, loading them on first use.

    The pair is stashed in module-level globals so repeated Streamlit
    reruns do not pay the weight-loading cost again.
    """
    global _tokenizer, _qg_model
    if _tokenizer is not None and _qg_model is not None:
        return _tokenizer, _qg_model
    print(f"[INFO] Loading QG model: {QG_MODEL_NAME} ...")
    _tokenizer = AutoTokenizer.from_pretrained(QG_MODEL_NAME, use_fast=False)
    _qg_model = T5ForConditionalGeneration.from_pretrained(QG_MODEL_NAME)
    _qg_model.eval()  # inference mode (disables dropout etc.)
    print("[INFO] Model loaded.")
    return _tokenizer, _qg_model
def highlight_answer(sentence: str, answer: str) -> str:
    """Wrap the first (case-insensitive) occurrence of *answer* in <hl> tags.

    The T5 question-generation model expects the target answer span to be
    marked as ``<hl> answer <hl>`` inside the source sentence.

    Bug fix: the replacement is supplied as a callable instead of a plain
    string.  With a string replacement, backslashes or group references in
    *answer* (e.g. ``\\1`` or Windows paths) are interpreted by ``re.sub``
    and either raise ``re.error`` or corrupt the output; a callable's
    return value is inserted literally.
    """
    pattern = re.compile(re.escape(answer), re.IGNORECASE)
    # count=1: only the first occurrence is highlighted.
    return pattern.sub(lambda _match: f"<hl> {answer} <hl>", sentence, count=1)
def answer_is_addressable(question: str, answer: str) -> bool:
    """
    Decide whether *question* is genuinely ASKING FOR *answer*.

    Rejects a question when any of the following hold:
      - it does not end with '?'
      - it does not open with a question word (what/who/when/...)
      - it has fewer than 5 words (too vague)
      - the answer text appears verbatim inside it (circular)
      - an acronym of the answer appears in it, e.g. "ISRO" in the
        question while the answer is "The Indian Space Research
        Organisation" (circular via abbreviation)
      - it matches a "what is the name / stand for / also known as"
        pattern, which is almost always circular
    """
    q = question.strip()
    a = answer.strip()
    q_lower = q.lower()

    # Guard: must actually be phrased as a question.
    if not q.endswith("?"):
        return False

    # Guard: must open with an interrogative word.
    interrogatives = ("what", "who", "when", "where", "which", "how", "why")
    if not q_lower.startswith(interrogatives):
        return False

    # Guard: reject overly short / vague questions.
    if len(q.split()) < 5:
        return False

    # Guard: circular if the answer is quoted inside the question.
    if a.lower() in q_lower:
        return False

    # Guard: acronym trap — initials of the answer's multi-letter words
    # appearing in the question (e.g. "isro" for "Indian Space Research
    # Organisation").
    initials = "".join(
        w[0].lower() for w in a.split() if len(w) > 1 and w.isalpha()
    )
    if len(initials) >= 2 and initials in q_lower:
        return False

    # Guard: "what is the name of ...?"-style questions are circular
    # nearly every time.
    circular_patterns = (
        r"what (is|was|were) the (full |official )?name",
        r"what (does|did) .{1,10} stand for",
        r"what (is|was) the abbreviation",
        r"what (is|was) .{1,15} also (known|called)",
    )
    return not any(re.search(pat, q_lower) for pat in circular_patterns)
def generate_question(sentence: str, answer: str) -> str | None:
    """
    Produce one validated question for a (sentence, answer) pair.

    Highlights the answer span, runs beam search on the T5 model, and
    returns the first decoded candidate that passes
    answer_is_addressable(); returns None when generation fails or
    every candidate is rejected.  Uses direct tokenizer + model
    inference (no pipeline()).
    """
    tokenizer, qg_model = _get_model()
    prompt = f"generate question: {highlight_answer(sentence, answer)}"

    try:
        encoded = tokenizer(
            prompt,
            return_tensors="pt",
            max_length=512,
            truncation=True,
        )
        with torch.no_grad():  # inference only — skip autograd bookkeeping
            beam_ids = qg_model.generate(
                encoded["input_ids"],
                max_new_tokens=64,
                num_beams=5,
                num_return_sequences=3,
                early_stopping=True,
            )
        candidates = [
            tokenizer.decode(ids, skip_special_tokens=True)
            for ids in beam_ids
        ]
    except Exception as e:
        print(f" [QG] Generation error: {e}")
        return None

    # Return the first candidate that survives the addressability filter.
    for candidate in candidates:
        candidate = candidate.strip()
        if not candidate.endswith("?"):
            candidate += "?"
        if answer_is_addressable(candidate, answer):
            return candidate
    return None
def generate_questions(sentence_answers: dict) -> list:
    """
    Build question records for each (sentence → answer candidates) entry.

    Candidates are tried in priority order; the first one that yields a
    valid question wins for that sentence.  Collection stops once
    MAX_QUESTIONS results have been gathered.

    Returns a list of dicts with keys "question", "answer", "sentence".
    """
    results = []
    for sentence, candidates in sentence_answers.items():
        if len(results) >= MAX_QUESTIONS:
            break
        accepted = None
        for answer in candidates:
            if len(answer.strip()) < 2:  # skip degenerate one-char answers
                continue
            question = generate_question(sentence, answer)
            if question is None:
                print(f" [QG] β Rejected for answer '{answer}'")
                continue
            print(f" [QG] β Q: {question}")
            print(f" A: {answer}")
            results.append({
                "question": question,
                "answer": answer,
                "sentence": sentence,
            })
            accepted = answer
            break
        if accepted is None:
            print(f" [QG] β No valid question for: '{sentence[:60]}'")
    return results
if __name__ == "__main__":
    # Smoke test: the first five pairs should yield questions; the last
    # two are built to trip the circular/abbreviation filters.
    tests = [
        # Good cases — specific named answers
        ("ISRO was founded in 1969 by Vikram Sarabhai.", "Vikram Sarabhai"),
        ("Aryabhata was India's first satellite, launched in 1975.", "Aryabhata"),
        ("The Chandrayaan-1 mission in 2008 discovered water on the Moon.", "2008"),
        ("Chandrayaan-3 landed near the lunar south pole in 2023.", "Chandrayaan-3"),
        ("The Taj Mahal was built by Shah Jahan in 1632 in Agra.", "Shah Jahan"),
        # Bad cases — should all be rejected
        ("The Indian Space Research Organisation (ISRO) was founded in 1969.", "The Indian Space Research Organisation"),
        ("ISRO developed India's first satellite.", "India"),
    ]
    print("\n=== QUESTION GENERATION TEST ===\n")
    for sentence, answer in tests:
        question = generate_question(sentence, answer)
        marker = "β" if question else "β (rejected)"
        print(f" [{marker}]")
        print(f" Sentence: {sentence}")
        print(f" Answer : {answer}")
        print(f" Question: {question}\n")