"""
Process Reward Model (PRM) scorer for step-level correctness.

Uses Qwen/Qwen2.5-Math-PRM-7B, a purpose-built process reward model that
assigns each reasoning step a probability of being correct.  This replaces
the "consensus voting across three samples from the same policy" signal,
which suffered from groupthink (three samples tend to agree because they
share the same failure modes) and was therefore uncorrelated with GSM8K
accuracy.

How PRM scoring works
---------------------
* The input is ``question`` + an assistant response where each reasoning
  step is separated by the special token ``<extra_0>`` (also appended
  after the final step).
* The model runs a single forward pass and emits a classification logit
  (``[negative, positive]``) at every ``<extra_0>`` position.
* A ``softmax`` over each pair gives the positive-class probability, which
  is the per-step reward in ``[0, 1]``.
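
Schematically, a two-step response is packed and scored as follows
(illustrative; ``p_step_N_correct`` are placeholder values)::

    assistant_body = step_1_text + "<extra_0>" + step_2_text + "<extra_0>"
    step_scores    = [p_step_1_correct, p_step_2_correct]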

Training integration
--------------------
The PRM is loaded once at startup alongside the policy and is called from the
rollout ``compute_reward`` path with no gradient flow.  It is quantised to
4-bit via ``bitsandbytes`` to keep VRAM under ~5 GB, leaving ample headroom
for policy training on a single 80 GB A100.
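
Usage sketch (illustrative; ``question`` and ``rollout_text`` are placeholder
strings, and which statistic to use as the scalar reward is left to the
caller)::

    scorer = ProcessRewardScorer()                  # once, at startup
    result = scorer.score_solution(question, rollout_text)
    reward = result["min_score"]                    # or mean_score / final_score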
"""

from __future__ import annotations

import logging
from typing import Any, Dict, List, Optional

import torch
import torch.nn.functional as F
from transformers import AutoModel, AutoTokenizer

from src.sft.solution_format import _step_bodies, extract_final_answer_numeric_str
from src.utils.attn_backend import select_attn_implementation

logger = logging.getLogger(__name__)


DEFAULT_SYSTEM_PROMPT = (
    "Please reason step by step, and put your final answer within \\boxed{}."
)
# Qwen PRM's step separator token.  Hard-coded by the model; do not change.
STEP_SEP_TOKEN = "<extra_0>"


def extract_prm_steps(solution: str) -> List[str]:
    """
    Split a Qwen-style ``Step N:`` solution into the text fragments the PRM
    expects: one element per reasoning step, with the final-answer line
    appended as a closing step so it gets its own correctness score.

    The ``Step N:`` prefix is stripped so we feed plain reasoning text
    (matches PRM's training distribution, which was Qwen-Math-Instruct
    paragraph-style outputs).
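
    Illustrative result for a two-step solution whose final numeric answer is
    ``72`` (the exact split depends on the ``solution_format`` helpers)::

        ["<step 1 body>", "<step 2 body>", "The answer is \\boxed{72}"]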
    """
    bodies = _step_bodies(solution)
    steps: List[str] = [b.strip() for b in bodies if b.strip()]
    final_raw = extract_final_answer_numeric_str(solution)
    if final_raw:
        steps.append(f"The answer is \\boxed{{{final_raw.strip()}}}")
    return steps


class ProcessRewardScorer:
    """
    Qwen2.5-Math-PRM-7B scorer.  Memory-efficient: the model is held in
    inference mode on the training device and runs in ``torch.no_grad``.
    """

    def __init__(
        self,
        model_name: str = "Qwen/Qwen2.5-Math-PRM-7B",
        device: Optional[torch.device] = None,
        load_in_4bit: bool = True,
        dtype: torch.dtype = torch.bfloat16,
        max_input_tokens: int = 4096,
    ):
        self.model_name = model_name
        self.device = device or torch.device(
            "cuda" if torch.cuda.is_available() else "cpu"
        )
        self.max_input_tokens = max_input_tokens

        logger.info(
            "Loading PRM %s (4-bit=%s, dtype=%s) on %s …",
            model_name, load_in_4bit, dtype, self.device,
        )
        self.tokenizer = AutoTokenizer.from_pretrained(
            model_name, trust_remote_code=True
        )

        load_kwargs: Dict[str, Any] = {
            "trust_remote_code": True,
            "torch_dtype": dtype,
            # PRM forward is eval-only but sequences can be 1-2k tokens
            # when the policy writes a lot of steps; flash-attn 2 cuts the
            # scoring forward by ~2x at those lengths.  Falls back to SDPA.
            "attn_implementation": select_attn_implementation(),
        }
        if load_in_4bit and torch.cuda.is_available():
            try:
                from transformers import BitsAndBytesConfig

                load_kwargs["quantization_config"] = BitsAndBytesConfig(
                    load_in_4bit=True,
                    bnb_4bit_compute_dtype=dtype,
                    bnb_4bit_quant_type="nf4",
                    bnb_4bit_use_double_quant=True,
                )
                load_kwargs["device_map"] = {"": self.device}
            except ImportError:
                logger.warning(
                    "bitsandbytes not available; falling back to bf16 PRM load"
                )
                load_in_4bit = False
        if not load_in_4bit:
            load_kwargs["device_map"] = {"": self.device}

        self.model = AutoModel.from_pretrained(model_name, **load_kwargs).eval()

        # Cache separator token id so we don't re-tokenize it every call.
        # encode() returns a list; the PRM's step separator is a single token.
        sep_ids = self.tokenizer.encode(STEP_SEP_TOKEN, add_special_tokens=False)
        if len(sep_ids) != 1:
            raise RuntimeError(
                f"PRM step separator {STEP_SEP_TOKEN!r} tokenized to "
                f"{sep_ids} (expected a single id).  Tokenizer mismatch."
            )
        self.step_sep_id = int(sep_ids[0])

        if torch.cuda.is_available():
            mem_alloc = torch.cuda.memory_allocated(self.device) / (1024 ** 3)
            logger.info(
                "PRM ready.  GPU memory allocated: %.2f GB  step_sep_id=%d",
                mem_alloc, self.step_sep_id,
            )

    @torch.no_grad()
    def score_solution(
        self,
        question: str,
        solution: str,
        system_prompt: str = DEFAULT_SYSTEM_PROMPT,
    ) -> Dict[str, Any]:
        """
        Return per-step correctness probabilities for ``solution``.

        Returns dict with:
            step_scores : List[float] - per-step probability in [0, 1]
            num_steps   : int         - number of scored steps
            mean_score  : float       - average across steps
            min_score   : float       - weakest step (error locator)
            final_score : float       - score on the answer-line step
            degraded    : bool        - True if we returned a zero-length
                                        score list (empty solution, etc.)
            degraded_reason : str     - present only when ``degraded`` is True
            padded_steps : bool       - True if truncation forced tail steps
                                        to be padded with the mean score
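
        Example (illustrative values)::

            {"step_scores": [0.97, 0.88, 0.42, 0.95], "num_steps": 4,
             "mean_score": 0.805, "min_score": 0.42, "final_score": 0.95,
             "degraded": False, "padded_steps": False}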
        """
        steps = extract_prm_steps(solution)
        if not steps:
            return {
                "step_scores": [],
                "num_steps": 0,
                "mean_score": 0.0,
                "min_score": 0.0,
                "final_score": 0.0,
                "degraded": True,
                "degraded_reason": "no extractable steps",
            }

        assistant_body = STEP_SEP_TOKEN.join(steps) + STEP_SEP_TOKEN
        messages = [
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": question.strip()},
            {"role": "assistant", "content": assistant_body},
        ]
        try:
            prompt = self.tokenizer.apply_chat_template(
                messages, tokenize=False, add_generation_prompt=False
            )
        except Exception as exc:
            logger.warning("PRM chat template failed: %s", exc)
            return {
                "step_scores": [],
                "num_steps": len(steps),
                "mean_score": 0.0,
                "min_score": 0.0,
                "final_score": 0.0,
                "degraded": True,
                "degraded_reason": f"chat template error: {exc}",
            }

        enc = self.tokenizer(
            prompt,
            return_tensors="pt",
            truncation=True,
            max_length=self.max_input_tokens,
        )
        input_ids = enc["input_ids"].to(self.device)
        attention_mask = enc.get("attention_mask")
        if attention_mask is not None:
            attention_mask = attention_mask.to(self.device)

        try:
            outputs = self.model(input_ids=input_ids, attention_mask=attention_mask)
        except Exception as exc:
            logger.warning("PRM forward pass failed: %s", exc)
            return {
                "step_scores": [],
                "num_steps": len(steps),
                "mean_score": 0.0,
                "min_score": 0.0,
                "final_score": 0.0,
                "degraded": True,
                "degraded_reason": f"forward error: {exc}",
            }

        logits = outputs[0]  # [1, seq_len, 2]
        token_mask = (input_ids == self.step_sep_id)  # [1, seq_len] bool

        # Follow the reference make_step_rewards routine.  We softmax the
        # logits, zero out non-separator positions, then read the positive
        # class (index 1) at each separator.
        probs = F.softmax(logits, dim=-1)  # [1, seq_len, 2]
        probs = probs * token_mask.unsqueeze(-1)
        sample = probs[0]  # [seq_len, 2]
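        # The nonzero filter below assumes separator-position probabilities are
        # never exactly 0.0 (safe in practice for a floating-point softmax).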
        positive_probs = sample[sample != 0].view(-1, 2)[:, 1]
        step_scores: List[float] = positive_probs.float().cpu().tolist()

        # Truncation may have dropped trailing separators.  Align lengths
        # conservatively by padding missing positions with the mean of what
        # we did see.  Log a warning so callers know the scores are partial.
        padded = False
        if len(step_scores) < len(steps) and step_scores:
            padded = True
            pad_val = float(sum(step_scores) / len(step_scores))
            n_padded = len(steps) - len(step_scores)
            step_scores = step_scores + [pad_val] * n_padded
            logger.warning(
                "PRM: %d/%d steps scored; %d tail step(s) padded with mean=%.3f "
                "(sequence likely truncated at %d tokens).",
                len(step_scores) - n_padded, len(steps), n_padded, pad_val,
                self.max_input_tokens,
            )
        elif len(step_scores) > len(steps):
            step_scores = step_scores[: len(steps)]

        if not step_scores:
            return {
                "step_scores": [],
                "num_steps": len(steps),
                "mean_score": 0.0,
                "min_score": 0.0,
                "final_score": 0.0,
                "degraded": True,
                "degraded_reason": "no separator token in output (truncated?)",
            }

        mean_score = float(sum(step_scores) / len(step_scores))
        min_score = float(min(step_scores))
        final_score = float(step_scores[-1])

        return {
            "step_scores": [float(s) for s in step_scores],
            "num_steps": len(step_scores),
            "mean_score": mean_score,
            "min_score": min_score,
            "final_score": final_score,
            "degraded": False,
            "padded_steps": len(step_scores) < len(steps),  # True if tail was padded
        }

    @torch.no_grad()
    def score_batch(
        self,
        items: List[Dict[str, str]],
        system_prompt: str = DEFAULT_SYSTEM_PROMPT,
    ) -> List[Dict[str, Any]]:
        """Score a list of ``{"question", "solution"}`` items sequentially.

        A proper padded batch path would be ~2-3x faster but needs care to
        handle variable separator counts.  Sequential is simple, correct,
        and a single PRM forward takes ~100-300 ms on an A100, which is
        acceptable overhead given that self-play generation dominates rollout
        wall-time.
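
        Example (illustrative; ``rollouts`` is a placeholder list of
        ``(question, solution)`` string pairs)::

            results = scorer.score_batch(
                [{"question": q, "solution": s} for q, s in rollouts]
            )
            rewards = [r["mean_score"] for r in results]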
        """
        return [
            self.score_solution(it["question"], it["solution"], system_prompt)
            for it in items
        ]
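

if __name__ == "__main__":  # pragma: no cover
    # Minimal manual smoke test (illustrative: downloads the 7B PRM weights on
    # first run, and the step/answer layout below is only an assumption about
    # what src.sft.solution_format expects).
    logging.basicConfig(level=logging.INFO)
    _scorer = ProcessRewardScorer(load_in_4bit=torch.cuda.is_available())
    _demo = _scorer.score_solution(
        question=(
            "Natalia sold clips to 48 of her friends in April, and then she "
            "sold half as many clips in May. How many clips did Natalia sell "
            "altogether in April and May?"
        ),
        solution=(
            "Step 1: In May she sold 48 / 2 = 24 clips.\n"
            "Step 2: Altogether she sold 48 + 24 = 72 clips.\n"
            "Final Answer: 72"
        ),
    )
    print(_demo)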