j-js commited on
Commit
670ac1a
·
verified ·
1 Parent(s): 00b3a52

Update quant_solver.py

Browse files
Files changed (1) hide show
  1. quant_solver.py +426 -120
quant_solver.py CHANGED
@@ -1,130 +1,436 @@
1
  from __future__ import annotations
2
 
3
- from typing import List, Optional
 
 
 
4
 
5
  try:
6
- from transformers import pipeline
7
  except Exception:
8
- pipeline = None
9
-
10
- from models import RetrievedChunk
11
-
12
-
13
- class GeneratorEngine:
14
- def __init__(self, model_name: str = "google/flan-t5-small"):
15
- self.model_name = model_name
16
- self.pipe = None
17
-
18
- if pipeline is not None:
19
- try:
20
- self.pipe = pipeline("text2text-generation", model=model_name)
21
- except Exception:
22
- self.pipe = None
23
-
24
- def available(self) -> bool:
25
- return self.pipe is not None
26
-
27
- def _notes_block(self, retrieval_context: List[RetrievedChunk]) -> str:
28
- if not retrieval_context:
29
- return ""
30
- lines = []
31
- for chunk in retrieval_context[:3]:
32
- text = (chunk.text or "").strip().replace("\n", " ")
33
- if len(text) > 220:
34
- text = text[:217].rstrip() + ""
35
- lines.append(f"- {chunk.topic}: {text}")
36
- return "\n".join(lines)
37
-
38
- def _template_fallback(
39
- self,
40
- user_text: str,
41
- question_text: Optional[str],
42
- topic: str,
43
- intent: str,
44
- retrieval_context: Optional[List[RetrievedChunk]] = None,
45
- ) -> str:
46
- question = (question_text or user_text or "").strip()
47
- notes = self._notes_block(retrieval_context or [])
48
-
49
- if intent == "hint":
50
- base = "Start by identifying the exact relationship between the quantities before doing any arithmetic."
51
- elif intent in {"instruction", "method"}:
52
- base = "Translate the wording into an equation, ratio, or percent relationship, then solve one step at a time."
53
- elif intent in {"walkthrough", "step_by_step", "explain", "concept"}:
54
- base = "First identify what the question is asking, then map the values into the correct quantitative structure, and only then compute."
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
55
  else:
56
- base = "This does not match a strong solver rule yet, so begin by identifying the target quantity and the relationship connecting the numbers."
57
-
58
- if notes:
59
- return f"{base}\n\nRelevant notes:\n{notes}"
60
- return base
61
-
62
- def _build_prompt(
63
- self,
64
- user_text: str,
65
- question_text: Optional[str],
66
- topic: str,
67
- intent: str,
68
- retrieval_context: Optional[List[RetrievedChunk]] = None,
69
- ) -> str:
70
- question = (question_text or user_text or "").strip()
71
- notes = self._notes_block(retrieval_context or [])
72
-
73
- prompt = [
74
- "You are a concise GMAT tutor.",
75
- f"Topic: {topic or 'general'}",
76
- f"Intent: {intent or 'answer'}",
77
- "",
78
- f"Question: {question}",
79
- ]
80
-
81
- if notes:
82
- prompt.extend(["", "Relevant teaching notes:", notes])
83
-
84
- prompt.extend(
85
- [
86
- "",
87
- "Respond briefly and clearly.",
88
- "If the problem is not fully solvable from the parse, give the next best method step.",
89
- "Do not invent facts.",
90
- ]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
91
  )
92
 
93
- return "\n".join(prompt)
94
-
95
- def generate(
96
- self,
97
- user_text: str,
98
- question_text: Optional[str] = None,
99
- topic: str = "",
100
- intent: str = "answer",
101
- retrieval_context: Optional[List[RetrievedChunk]] = None,
102
- chat_history=None,
103
- max_new_tokens: int = 96,
104
- **kwargs,
105
- ) -> Optional[str]:
106
- prompt = self._build_prompt(
107
- user_text=user_text,
108
- question_text=question_text,
109
- topic=topic,
110
- intent=intent,
111
- retrieval_context=retrieval_context or [],
112
  )
113
 
114
- if self.pipe is not None:
115
- try:
116
- out = self.pipe(prompt, max_new_tokens=max_new_tokens, do_sample=False)
117
- if out and isinstance(out, list):
118
- text = str(out[0].get("generated_text", "")).strip()
119
- if text:
120
- return text
121
- except Exception:
122
- pass
123
-
124
- return self._template_fallback(
125
- user_text=user_text,
126
- question_text=question_text,
127
- topic=topic,
128
- intent=intent,
129
- retrieval_context=retrieval_context or [],
130
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  from __future__ import annotations
2
 
3
+ import math
4
+ import re
5
+ from statistics import mean, median
6
+ from typing import Dict, List, Optional, Tuple
7
 
8
  try:
9
+ import sympy as sp
10
  except Exception:
11
+ sp = None
12
+
13
+ from models import SolverResult
14
+ from utils import clean_math_text, normalize_spaces
15
+
16
+
17
+ def extract_choices(text: str) -> Dict[str, str]:
18
+ text = text or ""
19
+ matches = list(
20
+ re.finditer(
21
+ r"(?i)\b([A-E])[\)\.:]\s*(.*?)(?=\s+\b[A-E][\)\.:]\s*|$)",
22
+ text,
23
+ )
24
+ )
25
+ return {m.group(1).upper(): normalize_spaces(m.group(2)) for m in matches}
26
+
27
+
28
+ def has_answer_choices(text: str) -> bool:
29
+ return len(extract_choices(text)) >= 3
30
+
31
+
32
+ def is_quant_question(text: str) -> bool:
33
+ lower = clean_math_text(text).lower()
34
+ keywords = [
35
+ "solve", "equation", "percent", "ratio", "probability", "mean", "median",
36
+ "average", "sum", "difference", "product", "quotient", "triangle", "circle",
37
+ "rectangle", "area", "perimeter", "volume", "algebra", "integer", "divisible",
38
+ "number", "fraction", "decimal", "geometry", "distance", "speed", "work",
39
+ "remainder", "discount",
40
+ ]
41
+ if any(k in lower for k in keywords):
42
+ return True
43
+ if "=" in lower and re.search(r"[a-z]", lower):
44
+ return True
45
+ if re.search(r"\d", lower) and ("?" in lower or has_answer_choices(lower)):
46
+ return True
47
+ return False
48
+
49
+
50
+ def _prepare_expression(expr: str) -> str:
51
+ expr = clean_math_text(expr).strip()
52
+ expr = expr.replace("^", "**")
53
+ expr = re.sub(r"(\d)\s*\(", r"\1*(", expr)
54
+ expr = re.sub(r"\)\s*(\d)", r")*\1", expr)
55
+ expr = re.sub(r"(\d)([a-zA-Z])", r"\1*\2", expr)
56
+ return expr
57
+
58
+
59
+ def _extract_equation(text: str) -> Optional[str]:
60
+ cleaned = clean_math_text(text)
61
+ if "=" not in cleaned:
62
+ return None
63
+
64
+ patterns = [
65
+ r"([A-Za-z0-9\.\+\-\*/\^\(\)\s]*[a-zA-Z][A-Za-z0-9\.\+\-\*/\^\(\)\s]*=[A-Za-z0-9\.\+\-\*/\^\(\)\s]+)",
66
+ r"([0-9A-Za-z\.\+\-\*/\^\(\)\s]+=[0-9A-Za-z\.\+\-\*/\^\(\)\s]+)",
67
+ ]
68
+
69
+ for pattern in patterns:
70
+ for m in re.finditer(pattern, cleaned):
71
+ candidate = m.group(1).strip()
72
+ if re.search(r"[a-z]", candidate.lower()) and not candidate.lower().startswith(
73
+ ("how do", "can you", "please", "what is", "solve ")
74
+ ):
75
+ return candidate
76
+
77
+ eq_index = cleaned.find("=")
78
+ left = re.findall(r"[A-Za-z0-9\.\+\-\*/\^\(\)\s]+$", cleaned[:eq_index])
79
+ right = re.findall(r"^[A-Za-z0-9\.\+\-\*/\^\(\)\s]+", cleaned[eq_index + 1:])
80
+ if left and right:
81
+ candidate = left[0].strip().split()[-1] + " = " + right[0].strip().split()[0]
82
+ if re.search(r"[a-z]", candidate.lower()):
83
+ return candidate
84
+ return None
85
+
86
+
87
+ def _parse_number(text: str) -> Optional[float]:
88
+ raw = clean_math_text(text).strip().lower()
89
+
90
+ pct = re.fullmatch(r"(-?\d+(?:\.\d+)?)%", raw.replace(" ", ""))
91
+ if pct:
92
+ return float(pct.group(1)) / 100.0
93
+
94
+ frac = re.fullmatch(r"(-?\d+)\s*/\s*(-?\d+)", raw)
95
+ if frac:
96
+ den = float(frac.group(2))
97
+ if den == 0:
98
+ return None
99
+ return float(frac.group(1)) / den
100
+
101
+ try:
102
+ return float(
103
+ eval(
104
+ _prepare_expression(raw),
105
+ {"__builtins__": {}},
106
+ {"sqrt": math.sqrt, "pi": math.pi},
107
+ )
108
+ )
109
+ except Exception:
110
+ return None
111
+
112
+
113
+ def _best_choice(answer_value: float, choices: Dict[str, str]) -> Optional[str]:
114
+ best_letter = None
115
+ best_diff = float("inf")
116
+
117
+ for letter, raw in choices.items():
118
+ parsed = _parse_number(raw)
119
+ if parsed is None:
120
+ continue
121
+ diff = abs(parsed - answer_value)
122
+ if diff < best_diff:
123
+ best_diff = diff
124
+ best_letter = letter
125
+
126
+ if best_letter is not None and best_diff <= 1e-6:
127
+ return best_letter
128
+ return None
129
+
130
+
131
+ def _make_result(
132
+ *,
133
+ topic: str,
134
+ answer_value: str,
135
+ internal_answer: Optional[str] = None,
136
+ steps: Optional[List[str]] = None,
137
+ choices_text: str = "",
138
+ ) -> SolverResult:
139
+ answer_float = _parse_number(answer_value)
140
+ choices = extract_choices(choices_text)
141
+ answer_letter = _best_choice(answer_float, choices) if (answer_float is not None and choices) else None
142
+
143
+ return SolverResult(
144
+ domain="quant",
145
+ solved=True,
146
+ topic=topic,
147
+ answer_value=answer_value,
148
+ answer_letter=answer_letter,
149
+ internal_answer=internal_answer or answer_value,
150
+ steps=steps or [],
151
+ )
152
+
153
+
154
+ def _solve_successive_percent(text: str) -> Optional[SolverResult]:
155
+ lower = clean_math_text(text).lower()
156
+
157
+ pattern = re.findall(
158
+ r"(increase|decrease|discount|mark(?:ed)?\s*up|mark(?:ed)?\s*down|rise|fall)\s+by\s+(\d+(?:\.\d+)?)\s*(?:%|percent)",
159
+ lower,
160
+ )
161
+ if len(pattern) < 2:
162
+ pattern = re.findall(
163
+ r"(\d+(?:\.\d+)?)\s*(?:%|percent)\s+(increase|decrease|discount|rise|fall)",
164
+ lower,
165
+ )
166
+ pattern = [(op, pct) for pct, op in pattern]
167
+
168
+ if len(pattern) < 2:
169
+ return None
170
+
171
+ multiplier = 1.0
172
+ step_lines: List[str] = []
173
+
174
+ for op, pct_raw in pattern:
175
+ pct = float(pct_raw)
176
+ if any(k in op for k in ["decrease", "discount", "down", "fall"]):
177
+ factor = 1 - pct / 100.0
178
+ step_lines.append(f"A {pct:g}% decrease means multiply by {factor:g}.")
179
+ else:
180
+ factor = 1 + pct / 100.0
181
+ step_lines.append(f"A {pct:g}% increase means multiply by {factor:g}.")
182
+ multiplier *= factor
183
+
184
+ net_change = (multiplier - 1.0) * 100.0
185
+ direction = "increase" if net_change >= 0 else "decrease"
186
+ magnitude = abs(net_change)
187
+
188
+ return _make_result(
189
+ topic="percent",
190
+ answer_value=f"{magnitude:g}%",
191
+ internal_answer=f"net {direction} of {magnitude:g}%",
192
+ steps=step_lines + [f"The combined multiplier gives a net {direction} of {magnitude:g}%."],
193
+ choices_text=text,
194
+ )
195
+
196
+
197
+ def _extract_ratio_labels(text: str) -> Optional[Tuple[str, str]]:
198
+ m = re.search(r"ratio of ([a-z ]+?) to ([a-z ]+?) is \d+\s*:\s*\d+", text.lower())
199
+ if not m:
200
+ return None
201
+ left = normalize_spaces(m.group(1)).rstrip("s")
202
+ right = normalize_spaces(m.group(2)).rstrip("s")
203
+ return left, right
204
+
205
+
206
+ def _solve_ratio_total(text: str) -> Optional[SolverResult]:
207
+ lower = clean_math_text(text).lower()
208
+
209
+ ratio_match = re.search(r"(\d+)\s*:\s*(\d+)", lower)
210
+ total_match = re.search(r"(?:total|altogether|in all|sum)\s*(?:is|=|of)?\s*(\d+)", lower)
211
+
212
+ if not ratio_match or not total_match:
213
+ return None
214
+
215
+ a = int(ratio_match.group(1))
216
+ b = int(ratio_match.group(2))
217
+ total = int(total_match.group(1))
218
+
219
+ part_sum = a + b
220
+ if part_sum == 0:
221
+ return None
222
+
223
+ unit = total / part_sum
224
+ left_value = a * unit
225
+ right_value = b * unit
226
+
227
+ labels = _extract_ratio_labels(lower)
228
+ requested_value = left_value
229
+ requested_label = "first quantity"
230
+
231
+ if labels:
232
+ left_label, right_label = labels
233
+ if left_label in lower and re.search(rf"how many {re.escape(left_label)}", lower):
234
+ requested_value = left_value
235
+ requested_label = left_label
236
+ elif right_label in lower and re.search(rf"how many {re.escape(right_label)}", lower):
237
+ requested_value = right_value
238
+ requested_label = right_label
239
  else:
240
+ requested_value = left_value
241
+ requested_label = left_label
242
+
243
+ return _make_result(
244
+ topic="ratio",
245
+ answer_value=f"{requested_value:g}",
246
+ internal_answer=f"{requested_label} = {requested_value:g}",
247
+ steps=[
248
+ f"Add the ratio parts: {a} + {b} = {part_sum}.",
249
+ f"Each ratio unit is {total} / {part_sum} = {unit:g}.",
250
+ f"Multiply by the required ratio part to get {requested_value:g}.",
251
+ ],
252
+ choices_text=text,
253
+ )
254
+
255
+
256
+ def _solve_remainder(text: str) -> Optional[SolverResult]:
257
+ lower = clean_math_text(text).lower()
258
+
259
+ m = re.search(r"remainder .*? when (\d+) is divided by (\d+)", lower)
260
+ if not m:
261
+ m = re.search(r"(\d+)\s*(?:mod|%)\s*(\d+)", lower)
262
+ if not m:
263
+ return None
264
+
265
+ a = int(m.group(1))
266
+ b = int(m.group(2))
267
+ if b == 0:
268
+ return None
269
+
270
+ r = a % b
271
+
272
+ return _make_result(
273
+ topic="number_theory",
274
+ answer_value=str(r),
275
+ internal_answer=str(r),
276
+ steps=[
277
+ f"Divide {a} by {b}.",
278
+ f"The remainder is {a} mod {b} = {r}.",
279
+ ],
280
+ choices_text=text,
281
+ )
282
+
283
+
284
+ def _solve_percent(text: str) -> Optional[SolverResult]:
285
+ lower = clean_math_text(text).lower()
286
+ choices = extract_choices(text)
287
+
288
+ m = re.search(r"(\d+(?:\.\d+)?)\s*(?:%|percent)\s+of\s+(?:a\s+)?number\s+is\s+(\d+(?:\.\d+)?)", lower)
289
+ if m:
290
+ p = float(m.group(1))
291
+ value = float(m.group(2))
292
+ ans = value / (p / 100.0)
293
+ answer_letter = _best_choice(ans, choices) if choices else None
294
+
295
+ return SolverResult(
296
+ domain="quant",
297
+ solved=True,
298
+ topic="percent",
299
+ answer_value=f"{ans:g}",
300
+ answer_letter=answer_letter,
301
+ internal_answer=f"{ans:g}",
302
+ steps=[
303
+ "Let the number be n.",
304
+ f"Write {p}% of n as {p / 100:g}n.",
305
+ f"Set {p / 100:g}n = {value} and solve for n.",
306
+ ],
307
  )
308
 
309
+ m = re.search(r"what is\s+(\d+(?:\.\d+)?)\s*(?:%|percent)\s+of\s+(\d+(?:\.\d+)?)", lower)
310
+ if m:
311
+ p = float(m.group(1))
312
+ n = float(m.group(2))
313
+ ans = p / 100.0 * n
314
+ answer_letter = _best_choice(ans, choices) if choices else None
315
+
316
+ return SolverResult(
317
+ domain="quant",
318
+ solved=True,
319
+ topic="percent",
320
+ answer_value=f"{ans:g}",
321
+ answer_letter=answer_letter,
322
+ internal_answer=f"{ans:g}",
323
+ steps=[
324
+ f"Convert {p}% to {p / 100:g}.",
325
+ f"Multiply by {n}.",
326
+ ],
 
327
  )
328
 
329
+ return None
330
+
331
+
332
+ def _solve_mean_median(text: str) -> Optional[SolverResult]:
333
+ lower = clean_math_text(text).lower()
334
+ nums = [float(n) for n in re.findall(r"-?\d+(?:\.\d+)?", lower)]
335
+ if not nums:
336
+ return None
337
+
338
+ if "mean" in lower or "average" in lower:
339
+ ans = mean(nums)
340
+ return SolverResult(
341
+ domain="quant",
342
+ solved=True,
343
+ topic="statistics",
344
+ answer_value=f"{ans:g}",
345
+ internal_answer=f"{ans:g}",
346
+ steps=["Add the values.", f"Divide by {len(nums)}."],
347
+ )
348
+
349
+ if "median" in lower:
350
+ ans = median(nums)
351
+ return SolverResult(
352
+ domain="quant",
353
+ solved=True,
354
+ topic="statistics",
355
+ answer_value=f"{ans:g}",
356
+ internal_answer=f"{ans:g}",
357
+ steps=["Order the values.", "Take the middle value."],
358
+ )
359
+
360
+ return None
361
+
362
+
363
+ def _solve_linear_equation(text: str) -> Optional[SolverResult]:
364
+ if sp is None:
365
+ return None
366
+
367
+ expr = _extract_equation(text)
368
+ if not expr:
369
+ return None
370
+
371
+ try:
372
+ lhs, rhs = expr.split("=", 1)
373
+ symbols = sorted(set(re.findall(r"\b[a-z]\b", expr)))
374
+ if not symbols:
375
+ return None
376
+
377
+ var_name = symbols[0]
378
+ var = sp.symbols(var_name)
379
+ sol = sp.solve(
380
+ sp.Eq(sp.sympify(_prepare_expression(lhs)), sp.sympify(_prepare_expression(rhs))),
381
+ var,
382
+ )
383
+ if not sol:
384
+ return None
385
+
386
+ value = sol[0]
387
+ try:
388
+ as_float = float(value)
389
+ except Exception:
390
+ as_float = None
391
+
392
+ choices = extract_choices(text)
393
+
394
+ return SolverResult(
395
+ domain="quant",
396
+ solved=True,
397
+ topic="algebra",
398
+ answer_value=str(value),
399
+ answer_letter=_best_choice(as_float, choices) if (as_float is not None and choices) else None,
400
+ internal_answer=f"{var_name} = {value}",
401
+ steps=[
402
+ "Treat the statement as an equation.",
403
+ "Undo operations on both sides to isolate the variable.",
404
+ f"That gives {var_name} = {value}.",
405
+ ],
406
+ )
407
+ except Exception:
408
+ return None
409
+
410
+
411
+ def solve_quant(text: str) -> SolverResult:
412
+ text = text or ""
413
+
414
+ for fn in (
415
+ _solve_successive_percent,
416
+ _solve_ratio_total,
417
+ _solve_remainder,
418
+ _solve_percent,
419
+ _solve_mean_median,
420
+ _solve_linear_equation,
421
+ ):
422
+ result = fn(text)
423
+ if result is not None:
424
+ return result
425
+
426
+ return SolverResult(
427
+ domain="quant",
428
+ solved=False,
429
+ topic="general_quant",
430
+ reply="This looks quantitative, but it does not match a strong rule-based pattern yet.",
431
+ steps=[
432
+ "Identify the quantity the question wants.",
433
+ "Translate the wording into an equation, ratio, or diagram.",
434
+ "Carry out the calculation carefully.",
435
+ ],
436
+ )