"""
Question quality evaluator for curriculum-guided dual-task training.
"""
from __future__ import annotations

import re
from dataclasses import dataclass
from typing import Dict, List, Optional

from src.rl.question_classifier import QuestionClassifier

@dataclass
class QuestionEvalResult:
    overall_score: float
    topic_match: float
    difficulty_score: float
    clarity: float
    solvability_score: float
    novelty_combined: float
    measured_difficulty: float
    detected_topic: Dict[str, object]
    novelty: Dict[str, float]
    solvability: Dict[str, object]

    def to_dict(self) -> Dict[str, object]:
        return {
            "overall_score": self.overall_score,
            "topic_match": self.topic_match,
            "difficulty_score": self.difficulty_score,
            "clarity": self.clarity,
            "solvability_score": self.solvability_score,
            "novelty_combined": self.novelty_combined,
            "measured_difficulty": self.measured_difficulty,
            "detected_topic": self.detected_topic,
            "novelty": self.novelty,
            "solvability": self.solvability,
        }

class QuestionQualityEvaluator:
    """Evaluate generated question quality for curriculum reward shaping."""

    def __init__(
        self,
        reference_questions: Optional[List[str]] = None,
        classifier: Optional[QuestionClassifier] = None,
        novelty_window_size: int = 500,  # raised from 100: 5 SP/iter → fills in ~100 iters
    ):
        self.reference_questions = reference_questions or []
        self.classifier = classifier or QuestionClassifier()
        self.novelty_window_size = novelty_window_size
        self.recent_questions: List[str] = []
        # Pre-compute and cache reference n-gram sets once at init.
        self._reference_ngrams = [self._extract_ngrams(q.lower()) for q in self.reference_questions]
        # Rolling cache of n-gram sets for recent questions (avoids recomputing every call).
        self._recent_ngrams: List[set] = []

    def evaluate(
        self,
        question: str,
        solution: str,
        consensus_result: Optional[Dict[str, object]],
        target_topic: str,
        target_difficulty: float,
    ) -> Dict[str, object]:
        detected_topic = self.classifier.classify_topic(question=question, solution=solution)
        topic_match = self._topic_match_score(detected_topic, target_topic)
        measured_difficulty = self.classifier.estimate_difficulty(
            question=question,
            solution=solution,
            consensus_result=consensus_result,
        )
        difficulty_score = max(0.0, 1.0 - 2.0 * abs(measured_difficulty - target_difficulty))
        clarity = self.classifier.check_clarity(question)
        novelty = self.compute_novelty_score(question)
        solvability = self.assess_solvability(question, solution, consensus_result)
        overall = (
            0.25 * topic_match
            + 0.15 * difficulty_score
            + 0.20 * clarity
            + 0.20 * float(solvability["score"])
            + 0.20 * novelty["combined"]  # raised 0.10 → 0.20; taken from difficulty_score
        )
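        # Illustrative sanity check with made-up component scores (not from the source
        # data): topic_match=1.0, difficulty_score=0.8, clarity=0.9, solvability=1.0,
        # novelty=0.7 gives overall = 0.25 + 0.12 + 0.18 + 0.20 + 0.14 = 0.89;
        # the five weights sum to 1.0, so overall stays in [0, 1] before clamping.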
        return QuestionEvalResult(
            overall_score=max(0.0, min(1.0, overall)),
            topic_match=topic_match,
            difficulty_score=difficulty_score,
            clarity=clarity,
            solvability_score=float(solvability["score"]),
            novelty_combined=novelty["combined"],
            measured_difficulty=measured_difficulty,
            detected_topic=detected_topic,
            novelty=novelty,
            solvability=solvability,
        ).to_dict()

    def compute_novelty_score(self, question: str) -> Dict[str, float]:
        dataset_novelty = self._novelty_against_reference(question, self._reference_ngrams)
        # Use cached recent n-gram sets instead of recomputing from strings each call (O(n²) → O(n)).
        session_novelty = self._novelty_against_reference(question, self._recent_ngrams)
        # Weight dataset novelty higher (60%): comparing against 8k GSM8K questions
        # is a stable, meaningful signal. Session novelty (40%) guards against
        # the model looping the same question template within a run.
        combined = max(0.0, min(1.0, 0.60 * dataset_novelty + 0.40 * session_novelty))
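        # Example with hypothetical values: dataset_novelty=0.8, session_novelty=0.5
        # gives combined = 0.60 * 0.8 + 0.40 * 0.5 = 0.68.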
        self.recent_questions.append(question)
        self.recent_questions = self.recent_questions[-self.novelty_window_size:]
        # Keep n-gram cache in sync with the question window.
        self._recent_ngrams.append(self._extract_ngrams(question.lower()))
        self._recent_ngrams = self._recent_ngrams[-self.novelty_window_size:]
        return {
            "combined": combined,
            "dataset_novelty": dataset_novelty,
            "session_novelty": session_novelty,
        }

    def assess_solvability(
        self,
        question: str,
        solution: str,
        consensus_result: Optional[Dict[str, object]],
    ) -> Dict[str, object]:
        q_lower = (question or "").lower()
        has_numbers = bool(re.search(r"\d", q_lower))
        has_question = ("?" in q_lower) or bool(re.search(
            r"\b(find|calculate|how many|what is|determine|compute|evaluate|express|simplify|solve)\b",
            q_lower,
        ))
        length_ok = 8 <= len(q_lower.split()) <= 120
        if not (has_numbers and has_question and length_ok):
            return {"solvable": False, "reason": "syntactic_failure", "score": 0.0}
        has_contradiction = bool(re.search(r"\b(impossible|cannot|undefined)\b", q_lower))
        if has_contradiction:
            return {"solvable": False, "reason": "semantic_failure", "score": 0.3}
        # PRM-based arithmetic quality check (replaces SymPy step verification).
        # consensus_strength = prm_mean: average PRM score across all reasoning steps.
        # A low PRM mean means the model produced inconsistent or incorrect reasoning,
        # which strongly signals the question is ambiguous, contradictory, or unsolvable.
        # PRM understands full mathematical semantics: it catches errors that SymPy
        # misses (e.g., wrong logic, incorrect setups) while not failing on valid prose.
        if consensus_result:
            confidence = float(consensus_result.get("consensus_strength", 0.5))
            if confidence < 0.30:
                # PRM rejects most steps → solution is invalid → question is likely unsolvable.
                return {"solvable": False, "reason": "low_prm_confidence", "score": 0.5}
            if not bool(consensus_result.get("has_majority", False)):
                # PRM is borderline (0.30–0.49) → uncertain solvability.
                return {"solvable": False, "reason": "no_consensus", "score": 0.6}
        else:
            confidence = 0.5
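        # Score tiers from the branches above: 0.0 syntactic failure, 0.3 semantic
        # failure, 0.5 low PRM confidence, 0.6 no consensus, 1.0 fully solvable.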
        return {
            "solvable": True,
            "reason": "fully_solvable",
            "score": 1.0,
            "confidence": confidence,
        }

    @staticmethod
    def _extract_ngrams(text: str, n: int = 3) -> set[str]:
        # Character n-grams, e.g. "two cats" -> {"two", "wo ", "o c", " ca", "cat", "ats"}.
        normalized = re.sub(r"\s+", " ", (text or "").strip())
        if len(normalized) < n:
            return {normalized} if normalized else set()
        return {normalized[i : i + n] for i in range(len(normalized) - n + 1)}

    @staticmethod
    def _jaccard_similarity(set1: set[str], set2: set[str]) -> float:
        if not set1 or not set2:
            return 0.0
        union = set1 | set2
        if not union:
            return 0.0
        return len(set1 & set2) / len(union)

    def _novelty_against_reference(self, question: str, reference_sets: List[set[str]]) -> float:
        if not reference_sets:
            return 1.0
        current = self._extract_ngrams((question or "").lower())
        max_similarity = 0.0
        for ref_set in reference_sets:
            max_similarity = max(max_similarity, self._jaccard_similarity(current, ref_set))
        return max(0.0, 1.0 - max_similarity)

    @staticmethod
    def _topic_match_score(detected_topic: Dict[str, object], target_topic: str) -> float:
        primary = str(detected_topic.get("primary_topic", ""))
        secondary = [str(x) for x in detected_topic.get("secondary_topics", [])]
        confidence = float(detected_topic.get("confidence", 0.0))
        if primary == target_topic:
            return max(0.6, min(1.0, confidence))
        if target_topic in secondary:
            return max(0.4, min(0.8, confidence))
        return min(0.35, confidence)
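

if __name__ == "__main__":
    # Illustrative smoke test, a sketch only: it is not part of the training pipeline
    # and uses a duck-typed stub in place of the real QuestionClassifier (whose scores
    # come from a learned model), so every number below is a placeholder chosen purely
    # to exercise the code paths in this module.
    class _StubClassifier:
        def classify_topic(self, question: str, solution: str) -> Dict[str, object]:
            return {"primary_topic": "arithmetic", "secondary_topics": [], "confidence": 0.8}

        def estimate_difficulty(self, question, solution, consensus_result) -> float:
            return 0.4

        def check_clarity(self, question: str) -> float:
            return 0.9

    evaluator = QuestionQualityEvaluator(classifier=_StubClassifier())
    result = evaluator.evaluate(
        question="A farmer has 12 apples and gives away 5 of them. How many apples are left?",
        solution="12 - 5 = 7, so 7 apples are left.",
        consensus_result={"consensus_strength": 0.85, "has_majority": True},
        target_topic="arithmetic",
        target_difficulty=0.4,
    )
    # Prints the clamped overall score and the solvability reason ("fully_solvable" here).
    print(result["overall_score"], result["solvability"]["reason"])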