"""
Question quality evaluator for curriculum-guided dual-task training.
"""

from __future__ import annotations

import re
from dataclasses import dataclass
from typing import Dict, List, Optional

from src.rl.question_classifier import QuestionClassifier


@dataclass
class QuestionEvalResult:
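    """Structured quality scores for one generated question (see to_dict())."""
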
    overall_score: float
    topic_match: float
    difficulty_score: float
    clarity: float
    solvability_score: float
    novelty_combined: float
    measured_difficulty: float
    detected_topic: Dict[str, object]
    novelty: Dict[str, float]
    solvability: Dict[str, object]

    def to_dict(self) -> Dict[str, object]:
        return {
            "overall_score": self.overall_score,
            "topic_match": self.topic_match,
            "difficulty_score": self.difficulty_score,
            "clarity": self.clarity,
            "solvability_score": self.solvability_score,
            "novelty_combined": self.novelty_combined,
            "measured_difficulty": self.measured_difficulty,
            "detected_topic": self.detected_topic,
            "novelty": self.novelty,
            "solvability": self.solvability,
        }


class QuestionQualityEvaluator:
    """Evaluate generated question quality for curriculum reward shaping."""

    def __init__(
        self,
        reference_questions: Optional[List[str]] = None,
        classifier: Optional[QuestionClassifier] = None,
        novelty_window_size: int = 500,   # raised from 100: 5 SP/iter β†’ fills in ~100 iters
    ):
        self.reference_questions = reference_questions or []
        self.classifier = classifier or QuestionClassifier()
        self.novelty_window_size = novelty_window_size
        self.recent_questions: List[str] = []
        # Pre-compute and cache reference n-gram sets once at init.
        self._reference_ngrams = [self._extract_ngrams(q.lower()) for q in self.reference_questions]
        # Rolling cache of n-gram sets for recent questions (avoids recomputing every call).
        self._recent_ngrams: List[set[str]] = []

    def evaluate(
        self,
        question: str,
        solution: str,
        consensus_result: Optional[Dict[str, object]],
        target_topic: str,
        target_difficulty: float,
    ) -> Dict[str, object]:
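        """Score a generated question against the curriculum targets.

        Returns QuestionEvalResult.to_dict(): the clamped overall_score plus
        every component used to compute it.
        """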
        detected_topic = self.classifier.classify_topic(question=question, solution=solution)
        topic_match = self._topic_match_score(detected_topic, target_topic)

        measured_difficulty = self.classifier.estimate_difficulty(
            question=question,
            solution=solution,
            consensus_result=consensus_result,
        )
        difficulty_score = max(0.0, 1.0 - 2.0 * abs(measured_difficulty - target_difficulty))

        clarity = self.classifier.check_clarity(question)
        novelty = self.compute_novelty_score(question)
        solvability = self.assess_solvability(question, solution, consensus_result)

        overall = (
            0.25 * topic_match
            + 0.15 * difficulty_score
            + 0.20 * clarity
            + 0.20 * float(solvability["score"])
            + 0.20 * novelty["combined"]   # raised 0.10β†’0.20; taken from difficulty_score
        )

        return QuestionEvalResult(
            overall_score=max(0.0, min(1.0, overall)),
            topic_match=topic_match,
            difficulty_score=difficulty_score,
            clarity=clarity,
            solvability_score=float(solvability["score"]),
            novelty_combined=novelty["combined"],
            measured_difficulty=measured_difficulty,
            detected_topic=detected_topic,
            novelty=novelty,
            solvability=solvability,
        ).to_dict()

    def compute_novelty_score(self, question: str) -> Dict[str, float]:
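        """Novelty of `question` vs. the reference set and the rolling session window.

        Returns "combined", "dataset_novelty", and "session_novelty", each in [0, 1].
        Note the side effect: the question is appended to the rolling window, so a
        repeated question scores near-zero session novelty on later calls.
        """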
        dataset_novelty = self._novelty_against_reference(question, self._reference_ngrams)
        # Use cached recent n-gram sets instead of recomputing from strings each call (O(nΒ²)β†’O(n)).
        session_novelty = self._novelty_against_reference(question, self._recent_ngrams)
        # Weight dataset novelty higher (60%) β€” comparing against 8k GSM8K questions
        # is a stable, meaningful signal. Session novelty (40%) guards against
        # the model looping the same question template within a run.
        combined = max(0.0, min(1.0, 0.60 * dataset_novelty + 0.40 * session_novelty))

        self.recent_questions.append(question)
        self.recent_questions = self.recent_questions[-self.novelty_window_size:]
        # Keep n-gram cache in sync with the question window.
        self._recent_ngrams.append(self._extract_ngrams(question.lower()))
        self._recent_ngrams = self._recent_ngrams[-self.novelty_window_size:]

        return {
            "combined": combined,
            "dataset_novelty": dataset_novelty,
            "session_novelty": session_novelty,
        }

    def assess_solvability(
        self,
        question: str,
        solution: str,
        consensus_result: Optional[Dict[str, object]],
    ) -> Dict[str, object]:
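        """Tiered heuristic solvability check.

        Scores: 0.0 syntactic failure, 0.3 self-declared impossible/contradictory,
        0.5 PRM rejects the reasoning, 0.6 no consensus, 1.0 fully solvable.
        """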
        q_lower = (question or "").lower()
        has_numbers = bool(re.search(r"\d", q_lower))
        has_question = ("?" in q_lower) or bool(re.search(
            r"\b(find|calculate|how many|what is|determine|compute|evaluate|express|simplify|solve)\b",
            q_lower,
        ))
        length_ok = 8 <= len(q_lower.split()) <= 120
        if not (has_numbers and has_question and length_ok):
            return {"solvable": False, "reason": "syntactic_failure", "score": 0.0}

        has_contradiction = bool(re.search(r"\b(impossible|cannot|undefined)\b", q_lower))
        if has_contradiction:
            return {"solvable": False, "reason": "semantic_failure", "score": 0.3}

        # PRM-based arithmetic quality check (replaces SymPy step verification).
        # consensus_strength = prm_mean: average PRM score across all reasoning steps.
        # A low PRM mean means the model produced inconsistent or incorrect reasoning,
        # which strongly signals the question is ambiguous, contradictory, or unsolvable.
        # PRM understands full mathematical semantics β€” it catches errors that SymPy
        # misses (e.g., wrong logic, incorrect setups) while not failing on valid prose.
        if consensus_result:
            confidence = float(consensus_result.get("consensus_strength", 0.5))
            if confidence < 0.30:
                # PRM rejects most steps β†’ solution is invalid β†’ question is likely unsolvable
                return {"solvable": False, "reason": "low_prm_confidence", "score": 0.5}
            if not bool(consensus_result.get("has_majority", False)):
                # PRM is borderline (above the 0.30 floor but short of a majority) → uncertain solvability
                return {"solvable": False, "reason": "no_consensus", "score": 0.6}
        else:
            confidence = 0.5

        return {
            "solvable": True,
            "reason": "fully_solvable",
            "score": 1.0,
            "confidence": confidence,
        }

    @staticmethod
    def _extract_ngrams(text: str, n: int = 3) -> set[str]:
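        """Character n-grams (default trigrams) of whitespace-normalized text.

        >>> sorted(QuestionQualityEvaluator._extract_ngrams("the cat"))
        [' ca', 'cat', 'e c', 'he ', 'the']
        """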
        normalized = re.sub(r"\s+", " ", (text or "").strip())
        if len(normalized) < n:
            return {normalized} if normalized else set()
        return {normalized[i : i + n] for i in range(len(normalized) - n + 1)}

    @staticmethod
    def _jaccard_similarity(set1: set[str], set2: set[str]) -> float:
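        """Jaccard similarity |A ∩ B| / |A ∪ B| in [0, 1]; 0.0 if either set is empty.

        >>> QuestionQualityEvaluator._jaccard_similarity({"the", "cat"}, {"the", "dog"})
        0.3333333333333333
        """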
        if not set1 or not set2:
            return 0.0
        # Both sets are non-empty here, so the union cannot be empty.
        return len(set1 & set2) / len(set1 | set2)

    def _novelty_against_reference(self, question: str, reference_sets: List[set[str]]) -> float:
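        """1.0 minus the max Jaccard similarity to any reference set (1.0 if none)."""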
        if not reference_sets:
            return 1.0
        current = self._extract_ngrams((question or "").lower())
        max_similarity = 0.0
        for ref_set in reference_sets:
            max_similarity = max(max_similarity, self._jaccard_similarity(current, ref_set))
        return max(0.0, 1.0 - max_similarity)

    @staticmethod
    def _topic_match_score(detected_topic: Dict[str, object], target_topic: str) -> float:
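        """Map topic agreement to [0, 1]: primary match lands in [0.6, 1.0],
        secondary match in [0.4, 0.8], otherwise capped at 0.35."""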
        primary = str(detected_topic.get("primary_topic", ""))
        secondary = [str(x) for x in detected_topic.get("secondary_topics", [])]
        confidence = float(detected_topic.get("confidence", 0.0))
        if primary == target_topic:
            return max(0.6, min(1.0, confidence))
        if target_topic in secondary:
            return max(0.4, min(0.8, confidence))
        return min(0.35, confidence)
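

if __name__ == "__main__":
    # Minimal smoke-test sketch, not part of the training pipeline. The stub
    # below is a hypothetical stand-in for QuestionClassifier: only the three
    # methods and return shapes this module actually calls are implemented,
    # and the scores it returns are made up for illustration.
    class _StubClassifier:
        def classify_topic(self, question: str, solution: str) -> Dict[str, object]:
            return {"primary_topic": "arithmetic", "secondary_topics": [], "confidence": 0.8}

        def estimate_difficulty(self, question, solution, consensus_result) -> float:
            return 0.5

        def check_clarity(self, question: str) -> float:
            return 0.9

    evaluator = QuestionQualityEvaluator(
        reference_questions=["Tom has 3 apples and buys 4 more. How many apples does he have?"],
        classifier=_StubClassifier(),  # type: ignore[arg-type]
    )
    result = evaluator.evaluate(
        question="A train travels 120 miles in 2 hours. What is its average speed?",
        solution="120 / 2 = 60 miles per hour.",
        consensus_result={"consensus_strength": 0.85, "has_majority": True},
        target_topic="arithmetic",
        target_difficulty=0.5,
    )
    print(f"overall={result['overall_score']:.3f} solvability={result['solvability']['reason']}")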