j-js committed on
Commit
00b3a52
·
verified ·
1 Parent(s): dda1483

Update generator_engine.py

Browse files
Files changed (1) hide show
  1. generator_engine.py +127 -80
generator_engine.py CHANGED
@@ -1,82 +1,129 @@
1
  from __future__ import annotations
2
 
3
-
4
import re

# Word-boundary match for modulo wording. A plain `"mod" in q` substring test
# also fires on unrelated words such as "model", "modern", or "modified",
# which would wrongly route those questions into the number-theory branch.
_MOD_RE = re.compile(r"\bmod(?:ulo|ular)?\b")


def normalize_category(category: str | None) -> str:
    """Collapse a free-form category label onto a canonical section name.

    Known aliases map onto "Quantitative", "DataInsight", "Verbal", or
    "General". An unrecognized non-empty label is returned unchanged so the
    caller can still see what was supplied; a falsy label becomes "General".
    """
    c = (category or "").strip().lower()

    if c in {"quantitative", "quant", "q", "math"}:
        return "Quantitative"
    if c in {"datainsight", "data_insight", "data insight", "di", "data"}:
        return "DataInsight"
    if c in {"verbal", "v"}:
        return "Verbal"
    if c in {"general", "", "unknown", "none", "null"}:
        return "General"

    return category or "General"


def classify_question(question_text: str, category: str | None = None) -> dict:
    """Heuristically tag a question with category, topic, and problem type.

    Keyword rules are checked in priority order inside the normalized
    category branch. When no recognized category is supplied, quantitative or
    data keywords route the text through the matching branch via recursion.

    Returns a dict with "category", "topic", and "type" keys.
    """
    q = (question_text or "").lower()
    normalized = normalize_category(category)

    if normalized == "Quantitative":
        # Successive changes ("increase then decrease") outrank plain percent.
        if ("percent" in q or "%" in q) and any(
            k in q for k in ["then", "after", "followed by", "successive", "increase", "decrease", "discount"]
        ):
            return {"category": normalized, "topic": "percent", "type": "successive_percent"}

        if "percent" in q or "%" in q:
            return {"category": normalized, "topic": "percent", "type": "percent_change"}

        if "ratio" in q or ":" in q:
            return {"category": normalized, "topic": "ratio", "type": "ratio_total"}

        if "probability" in q or "chosen at random" in q:
            return {"category": normalized, "topic": "probability", "type": "simple_probability"}

        # Fixed: match "mod"/"modulo"/"modular" as whole words only, so that
        # words like "model" no longer trigger the number-theory branch.
        if "divisible" in q or "remainder" in q or _MOD_RE.search(q):
            return {"category": normalized, "topic": "number_theory", "type": "remainder_or_divisibility"}

        # NOTE(review): a bare "|" is taken as absolute-value notation; this
        # also matches pipes used for other purposes — acceptable heuristic.
        if "|" in q:
            return {"category": normalized, "topic": "algebra", "type": "absolute_value"}

        if any(k in q for k in ["circle", "radius", "circumference", "triangle", "perimeter", "area"]):
            return {"category": normalized, "topic": "geometry", "type": "geometry"}

        if any(k in q for k in ["average", "mean", "median"]):
            return {"category": normalized, "topic": "statistics", "type": "average"}

        if "sequence" in q:
            return {"category": normalized, "topic": "sequence", "type": "sequence"}

        if "=" in q:
            return {"category": normalized, "topic": "algebra", "type": "equation"}

        return {"category": normalized, "topic": "quant", "type": "general"}

    if normalized == "DataInsight":
        if "percent" in q or "%" in q:
            return {"category": normalized, "topic": "percent", "type": "percent_change"}
        if any(k in q for k in ["mean", "median", "distribution"]):
            return {"category": normalized, "topic": "statistics", "type": "distribution"}
        if any(k in q for k in ["correlation", "scatter", "trend", "table", "chart"]):
            return {"category": normalized, "topic": "data", "type": "correlation_or_graph"}
        return {"category": normalized, "topic": "data", "type": "general"}

    if normalized == "Verbal":
        if "meaning" in q or "definition" in q:
            return {"category": normalized, "topic": "vocabulary", "type": "definition"}
        if "grammatically" in q or "sentence correction" in q:
            return {"category": normalized, "topic": "grammar", "type": "sentence_correction"}
        if "argument" in q or "author" in q:
            return {"category": normalized, "topic": "reasoning", "type": "argument_analysis"}
        return {"category": normalized, "topic": "verbal", "type": "general"}

    # No recognized category: re-route through the branch whose keywords the
    # text contains, quantitative first.
    if any(k in q for k in ["percent", "%", "ratio", "remainder", "divisible", "probability", "circle", "triangle", "="]):
        return classify_question(question_text, "Quantitative")

    if any(k in q for k in ["table", "chart", "scatter", "trend", "distribution"]):
        return classify_question(question_text, "DataInsight")

    return {"category": "General", "topic": "unknown", "type": "unknown"}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  from __future__ import annotations
2
 
3
+ from typing import List, Optional
4
+
5
+ try:
6
+ from transformers import pipeline
7
+ except Exception:
8
+ pipeline = None
9
+
10
+ from models import RetrievedChunk
11
+
12
+
13
class GeneratorEngine:
    """Answer generator wrapping an optional HuggingFace text2text pipeline.

    If the `transformers` package is missing or the model fails to load, the
    engine degrades to deterministic template answers instead of raising, so
    callers always get a usable string from `generate`.
    """

    def __init__(self, model_name: str = "google/flan-t5-small"):
        # `pipe` stays None whenever transformers is unavailable or the model
        # cannot be constructed; callers can probe readiness via available().
        self.model_name = model_name
        self.pipe = None
        if pipeline is None:
            return
        try:
            self.pipe = pipeline("text2text-generation", model=model_name)
        except Exception:
            # Model download/load failure is non-fatal by design.
            self.pipe = None

    def available(self) -> bool:
        """Report whether the underlying generation pipeline loaded."""
        return self.pipe is not None

    @staticmethod
    def _render_note(chunk: RetrievedChunk) -> str:
        # Flatten the chunk text to a single line and cap it (217 chars kept
        # plus an ellipsis) so the prompt stays compact.
        snippet = (chunk.text or "").strip().replace("\n", " ")
        if len(snippet) > 220:
            snippet = snippet[:217].rstrip() + "…"
        return f"- {chunk.topic}: {snippet}"

    def _notes_block(self, retrieval_context: List[RetrievedChunk]) -> str:
        """Render at most three retrieved chunks as a bulleted note list."""
        if not retrieval_context:
            return ""
        return "\n".join(self._render_note(c) for c in retrieval_context[:3])

    def _template_fallback(
        self,
        user_text: str,
        question_text: Optional[str],
        topic: str,
        intent: str,
        retrieval_context: Optional[List[RetrievedChunk]] = None,
    ) -> str:
        """Deterministic reply used when no model pipeline is available.

        Picks a canned method hint keyed on `intent` and appends retrieved
        notes when any exist.
        """
        method_msg = (
            "Translate the wording into an equation, ratio, or percent "
            "relationship, then solve one step at a time."
        )
        walkthrough_msg = (
            "First identify what the question is asking, then map the values "
            "into the correct quantitative structure, and only then compute."
        )
        advice_by_intent = {
            "hint": (
                "Start by identifying the exact relationship between the "
                "quantities before doing any arithmetic."
            ),
            "instruction": method_msg,
            "method": method_msg,
            "walkthrough": walkthrough_msg,
            "step_by_step": walkthrough_msg,
            "explain": walkthrough_msg,
            "concept": walkthrough_msg,
        }
        advice = advice_by_intent.get(
            intent,
            "This does not match a strong solver rule yet, so begin by "
            "identifying the target quantity and the relationship connecting "
            "the numbers.",
        )

        context_notes = self._notes_block(retrieval_context or [])
        if context_notes:
            return f"{advice}\n\nRelevant notes:\n{context_notes}"
        return advice

    def _build_prompt(
        self,
        user_text: str,
        question_text: Optional[str],
        topic: str,
        intent: str,
        retrieval_context: Optional[List[RetrievedChunk]] = None,
    ) -> str:
        """Assemble the instruction prompt fed to the text2text pipeline."""
        asked = (question_text or user_text or "").strip()
        context_notes = self._notes_block(retrieval_context or [])

        parts: List[str] = [
            "You are a concise GMAT tutor.",
            f"Topic: {topic or 'general'}",
            f"Intent: {intent or 'answer'}",
            "",
            f"Question: {asked}",
        ]
        if context_notes:
            parts += ["", "Relevant teaching notes:", context_notes]
        parts += [
            "",
            "Respond briefly and clearly.",
            "If the problem is not fully solvable from the parse, give the next best method step.",
            "Do not invent facts.",
        ]
        return "\n".join(parts)

    def generate(
        self,
        user_text: str,
        question_text: Optional[str] = None,
        topic: str = "",
        intent: str = "answer",
        retrieval_context: Optional[List[RetrievedChunk]] = None,
        chat_history=None,
        max_new_tokens: int = 96,
        **kwargs,
    ) -> Optional[str]:
        """Produce a tutoring reply, preferring the model over templates.

        Falls back to `_template_fallback` when the pipeline is missing,
        raises, or returns an empty answer. `chat_history` and extra kwargs
        are accepted for interface compatibility but not used here.
        """
        prompt = self._build_prompt(
            user_text=user_text,
            question_text=question_text,
            topic=topic,
            intent=intent,
            retrieval_context=retrieval_context or [],
        )

        if self.pipe is not None:
            try:
                raw = self.pipe(prompt, max_new_tokens=max_new_tokens, do_sample=False)
                if raw and isinstance(raw, list):
                    answer = str(raw[0].get("generated_text", "")).strip()
                    if answer:
                        return answer
            except Exception:
                # Model failure is non-fatal; fall through to the template.
                pass

        return self._template_fallback(
            user_text=user_text,
            question_text=question_text,
            topic=topic,
            intent=intent,
            retrieval_context=retrieval_context or [],
        )