# AxiomForgeAI / src/rl/prm_scorer.py
"""
Process Reward Model (PRM) scorer for step-level correctness.
Uses Qwen/Qwen2.5-Math-PRM-7B β€” a purpose-built process reward model that
assigns each reasoning step a probability of being correct. This replaces
the "consensus voting across three samples from the same policy" signal,
which was groupthink (three samples agree because they share the same
failure mode) and therefore uncorrelated with GSM8K accuracy.
How PRM scoring works
---------------------
* The input is ``question`` plus an assistant response in which each
  reasoning step is separated by the special token ``<extra_0>`` (also
  appended after the final step).
* The model runs a single forward pass and emits a classification logit
  pair (``[negative, positive]``) at every ``<extra_0>`` position.
* ``softmax`` over that pair gives the positive-class probability, which
  is the per-step reward in ``[0, 1]``.
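
Illustrative input layout (question made up; whitespace simplified)::

    system:    Please reason step by step, and put your final answer
               within \\boxed{}.
    user:      Janet has 3 apples and buys 2 more. How many does she have?
    assistant: First, 3 + 2 = 5.<extra_0>The answer is \\boxed{5}.<extra_0>

Two ``<extra_0>`` positions, so the model scores two steps.
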
Training integration
--------------------
Loaded once at startup alongside the policy and invoked during rollout
``compute_reward`` calls (no gradient flow). Quantised to 4-bit via
``bitsandbytes`` to keep VRAM under ~5 GB, leaving ample headroom for
policy training on a single 80 GB A100.
"""
from __future__ import annotations
import logging
from typing import Any, Dict, List, Optional
import torch
import torch.nn.functional as F
from transformers import AutoModel, AutoTokenizer
from src.sft.solution_format import _step_bodies, extract_final_answer_numeric_str
from src.utils.attn_backend import select_attn_implementation
logger = logging.getLogger(__name__)
DEFAULT_SYSTEM_PROMPT = (
"Please reason step by step, and put your final answer within \\boxed{}."
)
# Qwen PRM's step separator token. Hard-coded by the model; do not change.
STEP_SEP_TOKEN = "<extra_0>"
def extract_prm_steps(solution: str) -> List[str]:
"""
    Split a Qwen-style ``Step N:`` solution into the text fragments the
    PRM expects: one element per reasoning step, with the final-answer
    line appended as a closing step so it gets its own correctness score.

    The ``Step N:`` prefix is stripped so we feed plain reasoning text,
    matching the PRM's training distribution (paragraph-style
    Qwen-Math-Instruct outputs).
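
    Illustrative (assuming ``_step_bodies`` splits on the ``Step N:``
    prefixes; the exact parsing lives in ``src.sft.solution_format``)::

        "Step 1: 3 + 2 = 5.\\nFinal answer: 5"
        -> ["3 + 2 = 5.", "The answer is \\boxed{5}"]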
"""
bodies = _step_bodies(solution)
steps: List[str] = [b.strip() for b in bodies if b.strip()]
final_raw = extract_final_answer_numeric_str(solution)
if final_raw:
steps.append(f"The answer is \\boxed{{{final_raw.strip()}}}")
return steps
class ProcessRewardScorer:
"""
Qwen2.5-Math-PRM-7B scorer. Memory-efficient: the model is held in
inference mode on the training device and runs in ``torch.no_grad``.
"""
def __init__(
self,
model_name: str = "Qwen/Qwen2.5-Math-PRM-7B",
device: Optional[torch.device] = None,
load_in_4bit: bool = True,
dtype: torch.dtype = torch.bfloat16,
max_input_tokens: int = 4096,
):
self.model_name = model_name
self.device = device or torch.device(
"cuda" if torch.cuda.is_available() else "cpu"
)
self.max_input_tokens = max_input_tokens
logger.info(
"Loading PRM %s (4-bit=%s, dtype=%s) on %s …",
model_name, load_in_4bit, dtype, self.device,
)
self.tokenizer = AutoTokenizer.from_pretrained(
model_name, trust_remote_code=True
)
load_kwargs: Dict[str, Any] = {
"trust_remote_code": True,
"torch_dtype": dtype,
# PRM forward is eval-only but sequences can be 1-2k tokens
# when the policy writes a lot of steps; flash-attn 2 cuts the
# scoring forward by ~2x at those lengths. Falls back to SDPA.
"attn_implementation": select_attn_implementation(),
}
        use_4bit = load_in_4bit and torch.cuda.is_available()
        if use_4bit:
            try:
                from transformers import BitsAndBytesConfig
                load_kwargs["quantization_config"] = BitsAndBytesConfig(
                    load_in_4bit=True,
                    bnb_4bit_compute_dtype=dtype,
                    bnb_4bit_quant_type="nf4",
                    bnb_4bit_use_double_quant=True,
                )
            except ImportError:
                logger.warning(
                    "bitsandbytes not available; falling back to a "
                    "full-precision (%s) PRM load", dtype,
                )
        # Pin the whole model to the target device regardless of the
        # quantisation path.
        load_kwargs["device_map"] = {"": self.device}
self.model = AutoModel.from_pretrained(model_name, **load_kwargs).eval()
        # Cache the separator token id so we don't re-tokenize it on every
        # call. encode() returns a list; the PRM step separator must be a
        # single token.
sep_ids = self.tokenizer.encode(STEP_SEP_TOKEN, add_special_tokens=False)
if len(sep_ids) != 1:
raise RuntimeError(
f"PRM step separator {STEP_SEP_TOKEN!r} tokenized to "
f"{sep_ids} (expected a single id). Tokenizer mismatch."
)
self.step_sep_id = int(sep_ids[0])
if torch.cuda.is_available():
mem_alloc = torch.cuda.memory_allocated(self.device) / (1024 ** 3)
logger.info(
"PRM ready. GPU memory allocated: %.2f GB step_sep_id=%d",
mem_alloc, self.step_sep_id,
)
@torch.no_grad()
def score_solution(
self,
question: str,
solution: str,
system_prompt: str = DEFAULT_SYSTEM_PROMPT,
) -> Dict[str, Any]:
"""
        Return per-step correctness probabilities for ``solution``.

        Returns a dict with:
            step_scores : List[float] -- per-step probability in [0, 1]
            num_steps   : int   -- number of reasoning steps
            mean_score  : float -- average across steps
            min_score   : float -- weakest step (error locator)
            final_score : float -- score on the answer-line step
            degraded    : bool  -- True if the score list came back empty
                                   (empty solution, scoring failure, etc.)

        Degraded results additionally carry ``degraded_reason``; successful
        ones carry ``padded_steps`` (True if trailing steps were mean-padded
        after truncation).
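
        Example return value (illustrative numbers):
            {"step_scores": [0.98, 0.91, 0.99], "num_steps": 3,
             "mean_score": 0.96, "min_score": 0.91, "final_score": 0.99,
             "degraded": False, "padded_steps": False}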
"""
steps = extract_prm_steps(solution)
if not steps:
return {
"step_scores": [],
"num_steps": 0,
"mean_score": 0.0,
"min_score": 0.0,
"final_score": 0.0,
"degraded": True,
"degraded_reason": "no extractable steps",
}
assistant_body = STEP_SEP_TOKEN.join(steps) + STEP_SEP_TOKEN
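        # e.g. "First, 3 + 2 = 5.<extra_0>The answer is \boxed{5}.<extra_0>"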
messages = [
{"role": "system", "content": system_prompt},
{"role": "user", "content": question.strip()},
{"role": "assistant", "content": assistant_body},
]
try:
prompt = self.tokenizer.apply_chat_template(
messages, tokenize=False, add_generation_prompt=False
)
except Exception as exc:
logger.warning("PRM chat template failed: %s", exc)
return {
"step_scores": [],
"num_steps": len(steps),
"mean_score": 0.0,
"min_score": 0.0,
"final_score": 0.0,
"degraded": True,
"degraded_reason": f"chat template error: {exc}",
}
enc = self.tokenizer(
prompt,
return_tensors="pt",
truncation=True,
max_length=self.max_input_tokens,
)
input_ids = enc["input_ids"].to(self.device)
attention_mask = enc.get("attention_mask")
if attention_mask is not None:
attention_mask = attention_mask.to(self.device)
try:
outputs = self.model(input_ids=input_ids, attention_mask=attention_mask)
except Exception as exc:
logger.warning("PRM forward pass failed: %s", exc)
return {
"step_scores": [],
"num_steps": len(steps),
"mean_score": 0.0,
"min_score": 0.0,
"final_score": 0.0,
"degraded": True,
"degraded_reason": f"forward error: {exc}",
}
        logits = outputs[0]  # [1, seq_len, 2]
        token_mask = input_ids == self.step_sep_id  # [1, seq_len] bool
        # Follow the reference make_step_rewards routine: softmax the
        # logits, then read the positive-class probability (index 1) at
        # each separator position. Indexing with the boolean mask is
        # equivalent to the reference's "zero out, gather nonzero" trick
        # but robust to probabilities that underflow to exactly 0.
        probs = F.softmax(logits, dim=-1)  # [1, seq_len, 2]
        positive_probs = probs[0][token_mask[0]][:, 1]  # [num_separators]
        step_scores: List[float] = positive_probs.float().cpu().tolist()
        # Truncation may have dropped trailing separators. Align lengths
        # conservatively by padding the missing positions with the mean of
        # what we did see, and warn so callers know the scores are partial.
        n_scored = len(step_scores)
        padded = False
        if 0 < n_scored < len(steps):
            pad_val = float(sum(step_scores) / n_scored)
            n_padded = len(steps) - n_scored
            step_scores = step_scores + [pad_val] * n_padded
            padded = True
            logger.warning(
                "PRM: %d/%d steps scored; %d tail step(s) padded with "
                "mean=%.3f (sequence likely truncated at %d tokens).",
                n_scored, len(steps), n_padded, pad_val,
                self.max_input_tokens,
            )
        elif n_scored > len(steps):
            step_scores = step_scores[: len(steps)]
if not step_scores:
return {
"step_scores": [],
"num_steps": len(steps),
"mean_score": 0.0,
"min_score": 0.0,
"final_score": 0.0,
"degraded": True,
"degraded_reason": "no separator token in output (truncated?)",
}
mean_score = float(sum(step_scores) / len(step_scores))
min_score = float(min(step_scores))
final_score = float(step_scores[-1])
return {
"step_scores": [float(s) for s in step_scores],
"num_steps": len(step_scores),
"mean_score": mean_score,
"min_score": min_score,
"final_score": final_score,
"degraded": False,
"padded_steps": len(step_scores) < len(steps), # True if tail was padded
}
@torch.no_grad()
def score_batch(
self,
items: List[Dict[str, str]],
system_prompt: str = DEFAULT_SYSTEM_PROMPT,
) -> List[Dict[str, Any]]:
"""Score a list of ``{"question", "solution"}`` items sequentially.
A proper padded batch path would be ~2-3Γ— faster but needs care to
handle variable separator counts. Sequential is simple, correct,
and a single PRM forward takes ~100-300 ms on an A100 β€” acceptable
overhead given self-play generation dominates rollout wall-time.
"""
return [
self.score_solution(it["question"], it["solution"], system_prompt)
for it in items
]
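

if __name__ == "__main__":
    # Minimal smoke test (illustrative: the question/solution are made up,
    # and the "Step N:" / "Final answer:" layout must match whatever
    # src.sft.solution_format expects). Downloads the ~7B PRM weights on
    # first run; the 4-bit path needs a CUDA GPU.
    logging.basicConfig(level=logging.INFO)
    scorer = ProcessRewardScorer()
    print(
        scorer.score_solution(
            question="Janet has 3 apples and buys 2 more. How many does she have?",
            solution=(
                "Step 1: She starts with 3 apples and buys 2, so 3 + 2 = 5.\n"
                "Final answer: 5"
            ),
        )
    )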