# ─────────────────────────────────────────────
#  src/question_generator.py  (v5)
#  Fix: replaced removed pipeline("text2text-generation")
#  with direct AutoTokenizer + T5ForConditionalGeneration
#  inference, compatible with transformers >= 4.46.
# ─────────────────────────────────────────────

import os
import re
import sys
import warnings
import torch
from transformers import AutoTokenizer, T5ForConditionalGeneration

# Make the project root importable so `config` resolves when run from src/.
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from config import QG_MODEL_NAME, MAX_QUESTIONS

warnings.filterwarnings("ignore")   # silence all warnings, incl. HF FutureWarnings

# ── Lazy model cache (avoids re-loading on every Streamlit rerun) ──
_tokenizer = None
_qg_model  = None

def _get_model():
    """Load model once and cache in module-level globals."""
    global _tokenizer, _qg_model
    if _tokenizer is None or _qg_model is None:
        print(f"[INFO] Loading QG model: {QG_MODEL_NAME} ...")
        _tokenizer = AutoTokenizer.from_pretrained(QG_MODEL_NAME, use_fast=False)
        _qg_model  = T5ForConditionalGeneration.from_pretrained(QG_MODEL_NAME)
        _qg_model.eval()
        print("[INFO] Model loaded.")
    return _tokenizer, _qg_model
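
# Note: in a Streamlit app the loader could also be wrapped in
# st.cache_resource; module-level globals are used here so this module has
# no hard dependency on Streamlit.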


def highlight_answer(sentence: str, answer: str) -> str:
    """Wrap the answer with <hl> tags for the T5 model."""
    pattern = re.compile(re.escape(answer), re.IGNORECASE)
    # Use a lambda replacement so backslashes in the answer are not
    # interpreted as regex group references by re.sub.
    return pattern.sub(lambda _m: f"<hl> {answer} <hl>", sentence, count=1)
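
# Illustrative example (input format assumed by "qg-hl"-style T5 checkpoints):
#   highlight_answer("ISRO was founded in 1969.", "1969")
#   -> "ISRO was founded in <hl> 1969 <hl>."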


def answer_is_addressable(question: str, answer: str) -> bool:
    """
    Check that the question is actually ASKING FOR the answer.
    
    Rejects:
    - Circular: answer text appears in the question
      e.g. Q: "What is the name of ISRO?" A: "The Indian Space Research Organisation"
           (ISRO is an abbreviation of the answer - circular)
    - Too vague: question is only 4 words or fewer
    - No question word
    - Answer is a substring of the question
    """
    q = question.strip()
    a = answer.strip()

    # Must end with ?
    if not q.endswith("?"):
        return False

    # Must have a question word
    q_lower = q.lower()
    if not any(q_lower.startswith(w) for w in
               ["what", "who", "when", "where", "which", "how", "why"]):
        return False

    # Must be at least 5 words
    if len(q.split()) < 5:
        return False

    # Answer must NOT appear verbatim in the question
    if a.lower() in q_lower:
        return False

    # Check the abbreviation trap: the question may contain an acronym of the
    # answer (e.g. "ISRO" in the question when the answer is
    # "The Indian Space Research Organisation"). Leading articles and filler
    # words are skipped so "The ..." yields "ISRO" rather than "TISRO", and
    # the acronym is matched on word boundaries to avoid hits inside longer words.
    answer_words = [w.lower() for w in a.split() if len(w) > 1]
    answer_words = [w for w in answer_words if w not in ("the", "of", "and")]
    abbrev = "".join(w[0] for w in answer_words if w.isalpha())
    if len(abbrev) >= 2 and re.search(rf"\b{re.escape(abbrev)}\b", q_lower):
        return False

    # Reject questions asking about name/abbreviation - usually circular
    circular_patterns = [
        r"what (is|was|were) the (full |official )?name",
        r"what (does|did) .{1,10} stand for",
        r"what (is|was) the abbreviation",
        r"what (is|was) .{1,15} also (known|called)",
    ]
    for pat in circular_patterns:
        if re.search(pat, q_lower):
            return False

    return True
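
# Illustrative checks (hypothetical inputs):
#   answer_is_addressable("Who founded ISRO in 1969?", "Vikram Sarabhai")
#       -> True
#   answer_is_addressable("What is the name of ISRO?",
#                         "The Indian Space Research Organisation")
#       -> False  (circular: an acronym of the answer appears in the question)
#   answer_is_addressable("Who was he?", "Vikram Sarabhai")
#       -> False  (fewer than 5 words)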


def generate_question(sentence: str, answer: str) -> str | None:
    """
    Generate a question for a (sentence, answer) pair.
    Returns the best valid question string, or None.
    Uses direct tokenizer + model inference (no pipeline()).
    """
    tokenizer, qg_model = _get_model()

    highlighted = highlight_answer(sentence, answer)
    input_text  = f"generate question: {highlighted}"

    try:
        inputs = tokenizer(
            input_text,
            return_tensors = "pt",
            max_length     = 512,
            truncation     = True,
        )
        with torch.no_grad():
            output_ids = qg_model.generate(
                **inputs,   # pass attention_mask along with input_ids
                max_new_tokens       = 64,
                num_beams            = 5,
                num_return_sequences = 3,
                early_stopping       = True,
            )
        outputs = [
            tokenizer.decode(ids, skip_special_tokens=True)
            for ids in output_ids
        ]
    except Exception as e:
        print(f"  [QG] Generation error: {e}")
        return None

    for q in outputs:
        q = q.strip()
        if not q.endswith("?"):
            q += "?"
        if answer_is_addressable(q, answer):
            return q

    return None


def generate_questions(sentence_answers: dict[str, list[str]]) -> list[dict]:
    """
    For each (sentence β†’ answer candidates), generate one good question.
    Tries each answer candidate in priority order until one works.
    """
    results = []

    for sentence, candidates in sentence_answers.items():
        if len(results) >= MAX_QUESTIONS:
            break

        generated = False
        for answer in candidates:
            if len(answer.strip()) < 2:
                continue

            question = generate_question(sentence, answer)

            if question:
                print(f"  [QG] βœ“ Q: {question}")
                print(f"         A: {answer}")
                results.append({
                    "question" : question,
                    "answer"   : answer,
                    "sentence" : sentence,
                })
                generated = True
                break
            else:
                print(f"  [QG] βœ— Rejected for answer '{answer}'")

        if not generated:
            print(f"  [QG] β€” No valid question for: '{sentence[:60]}'")

    return results
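
# Illustrative usage (hypothetical data; generated question text varies by
# model checkpoint):
#   generate_questions({
#       "ISRO was founded in 1969 by Vikram Sarabhai.": ["Vikram Sarabhai", "1969"],
#   })
#   -> [{"question": "Who founded ISRO in 1969?",
#        "answer"  : "Vikram Sarabhai",
#        "sentence": "ISRO was founded in 1969 by Vikram Sarabhai."}]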


if __name__ == "__main__":
    tests = [
        # Good cases - specific named answers
        ("ISRO was founded in 1969 by Vikram Sarabhai.", "Vikram Sarabhai"),
        ("Aryabhata was India's first satellite, launched in 1975.", "Aryabhata"),
        ("The Chandrayaan-1 mission in 2008 discovered water on the Moon.", "2008"),
        ("Chandrayaan-3 landed near the lunar south pole in 2023.", "Chandrayaan-3"),
        ("The Taj Mahal was built by Shah Jahan in 1632 in Agra.", "Shah Jahan"),
        # Bad cases - should all be rejected
        ("The Indian Space Research Organisation (ISRO) was founded in 1969.", "The Indian Space Research Organisation"),
        ("ISRO developed India's first satellite.", "India"),
    ]

    print("\n=== QUESTION GENERATION TEST ===\n")
    for sentence, answer in tests:
        q = generate_question(sentence, answer)
        status = "βœ“" if q else "βœ— (rejected)"
        print(f"  [{status}]")
        print(f"    Sentence: {sentence}")
        print(f"    Answer  : {answer}")
        print(f"    Question: {q}\n")