tanmmayyy committed on
Commit
6576362
·
1 Parent(s): dad0ef3
Files changed (1) hide show
  1. src/question_generator.py +40 -37
src/question_generator.py CHANGED
@@ -1,46 +1,35 @@
1
  # ─────────────────────────────────────────────
2
- # src/question_generator.py (v4)
3
- # Key fix: validate that the generated question
4
- # actually targets the intended answer.
5
- # Also filters circular questions like
6
- # "What is the name of X?" when answer IS X.
7
  # ─────────────────────────────────────────────
8
 
9
- from transformers import pipeline
10
  import re
11
  import sys, os
 
 
12
  from transformers import AutoTokenizer, T5ForConditionalGeneration
13
- import streamlit as st
14
-
15
- @st.cache_resource
16
- def load_model():
17
- tokenizer = AutoTokenizer.from_pretrained("valhalla/t5-small-qg-hl", use_fast=False)
18
- model = T5ForConditionalGeneration.from_pretrained("valhalla/t5-small-qg-hl")
19
- model.eval()
20
- return tokenizer, model
21
-
22
- tokenizer, qg_model = load_model()
23
-
24
 
25
  sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
26
  from config import QG_MODEL_NAME, MAX_QUESTIONS
27
 
28
- print(f"[INFO] Loading QG model: {QG_MODEL_NAME} ...")
29
- import warnings
30
  warnings.filterwarnings("ignore") # suppress HuggingFace FutureWarnings
31
 
 
 
 
32
 
33
-
34
- try:
35
- qg_pipeline = pipeline(
36
- "text2text-generation",
37
- model = QG_MODEL_NAME,
38
- tokenizer = QG_MODEL_NAME,
39
- )
40
- print("[INFO] Model loaded.")
41
- except Exception as e:
42
- print(f"[ERROR] {e}")
43
- raise
44
 
45
 
46
  def highlight_answer(sentence: str, answer: str) -> str:
@@ -108,24 +97,38 @@ def generate_question(sentence: str, answer: str) -> str | None:
108
  """
109
  Generate a question for a (sentence, answer) pair.
110
  Returns the best valid question string, or None.
 
111
  """
 
 
112
  highlighted = highlight_answer(sentence, answer)
113
  input_text = f"generate question: {highlighted}"
114
 
115
  try:
116
- outputs = qg_pipeline(
117
  input_text,
118
- max_new_tokens = 64,
119
- num_beams = 5,
120
- num_return_sequences = 3,
121
- early_stopping = True,
122
  )
 
 
 
 
 
 
 
 
 
 
 
 
123
  except Exception as e:
124
  print(f" [QG] Generation error: {e}")
125
  return None
126
 
127
- for output in outputs:
128
- q = output["generated_text"].strip()
129
  if not q.endswith("?"):
130
  q += "?"
131
  if answer_is_addressable(q, answer):
 
1
  # ─────────────────────────────────────────────
2
+ # src/question_generator.py (v5)
3
+ # Fix: replaced removed pipeline("text2text-generation")
4
+ # with direct AutoTokenizer + T5ForConditionalGeneration
5
+ # inference, compatible with transformers >= 4.46.
 
6
  # ─────────────────────────────────────────────
7
 
 
8
  import re
9
  import sys, os
10
+ import warnings
11
+ import torch
12
  from transformers import AutoTokenizer, T5ForConditionalGeneration
 
 
 
 
 
 
 
 
 
 
 
13
 
14
  sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
15
  from config import QG_MODEL_NAME, MAX_QUESTIONS
16
 
 
 
17
  warnings.filterwarnings("ignore") # suppress HuggingFace FutureWarnings
18
 
19
# ── Lazy model cache (avoids re-loading on every Streamlit rerun) ──
# Both slots stay None until the first _get_model() call fills them.
_tokenizer = None
_qg_model = None


def _get_model():
    """Return the cached ``(tokenizer, model)`` pair, loading it on first use.

    The first call loads the tokenizer and the T5 model named by
    ``QG_MODEL_NAME``, switches the model to eval mode, and stores both
    in module-level globals. Later calls return the stored objects as-is.
    If either slot is still ``None`` (e.g. a previous load raised midway),
    both are loaded again.
    """
    global _tokenizer, _qg_model

    # Fast path: everything is already in memory.
    if _tokenizer is not None and _qg_model is not None:
        return _tokenizer, _qg_model

    print(f"[INFO] Loading QG model: {QG_MODEL_NAME} ...")
    _tokenizer = AutoTokenizer.from_pretrained(QG_MODEL_NAME, use_fast=False)
    _qg_model = T5ForConditionalGeneration.from_pretrained(QG_MODEL_NAME)
    _qg_model.eval()
    print("[INFO] Model loaded.")
    return _tokenizer, _qg_model
 
33
 
34
 
35
  def highlight_answer(sentence: str, answer: str) -> str:
 
97
  """
98
  Generate a question for a (sentence, answer) pair.
99
  Returns the best valid question string, or None.
100
+ Uses direct tokenizer + model inference (no pipeline()).
101
  """
102
+ tokenizer, qg_model = _get_model()
103
+
104
  highlighted = highlight_answer(sentence, answer)
105
  input_text = f"generate question: {highlighted}"
106
 
107
  try:
108
+ inputs = tokenizer(
109
  input_text,
110
+ return_tensors = "pt",
111
+ max_length = 512,
112
+ truncation = True,
 
113
  )
114
+ with torch.no_grad():
115
+ output_ids = qg_model.generate(
116
+ inputs["input_ids"],
117
+ max_new_tokens = 64,
118
+ num_beams = 5,
119
+ num_return_sequences = 3,
120
+ early_stopping = True,
121
+ )
122
+ outputs = [
123
+ tokenizer.decode(ids, skip_special_tokens=True)
124
+ for ids in output_ids
125
+ ]
126
  except Exception as e:
127
  print(f" [QG] Generation error: {e}")
128
  return None
129
 
130
+ for q in outputs:
131
+ q = q.strip()
132
  if not q.endswith("?"):
133
  q += "?"
134
  if answer_is_addressable(q, answer):