# src/pytorch_debug_env/reward.py
"""Reward shaping for the PyTorch-debugging RL environment.

Scores the agent's running hypothesis and its committed diagnosis against
the episode's ground truth, and combines the pieces into a per-step scalar
reward plus a component breakdown for logging.
"""
from __future__ import annotations

from .bug_library import BUG_CATEGORIES

# Scores are clamped to the open interval (EPSILON, 1 - EPSILON) so the
# validator never sees an exact 0.0 or 1.0.
EPSILON = 1e-2

# Sentinel meaning "no line range was provided". Real source ranges are
# presumably 1-indexed, so [0, 0] never denotes actual lines — TODO confirm
# against the task generator.
_NO_RANGE = [0, 0]


def clamp_score(value: float) -> float:
    """Clamp scores to the open interval (0, 1) for validator compliance."""
    return min(max(value, EPSILON), 1.0 - EPSILON)


def hypothesis_quality(hypothesis: dict, ground_truth: dict) -> float:
    """Score how well the current hypothesis matches the ground truth.

    Components (maximum 1.0 total):
      * affected file: 0.45 exact match, 0.15 for a related file
      * bug type:      0.40 exact match, 0.13 for the same category
      * calibration:   0.15 * (1 - |confidence - quality so far|)

    Returns the score rounded to 4 decimal places.
    """
    quality = 0.0
    if hypothesis.get("affected_file") == ground_truth["primary_bug_file"]:
        quality += 0.45
    elif hypothesis.get("affected_file") in ground_truth.get("related_files", []):
        quality += 0.15

    if hypothesis.get("bug_type") == ground_truth["bug_type"]:
        quality += 0.40
    else:
        # Partial credit for naming the right *category* of bug. Guard the
        # comparison: if both lookups miss, None == None must not count as
        # a category match.
        hypothesized_category = BUG_CATEGORIES.get(hypothesis.get("bug_type"))
        if (
            hypothesized_category is not None
            and hypothesized_category == BUG_CATEGORIES.get(ground_truth["bug_type"])
        ):
            quality += 0.13

    # Reward confidence that tracks the evidence-based quality accumulated
    # so far (a well-calibrated agent, not just a confident one).
    calibration = 1.0 - abs(hypothesis.get("confidence", 0.5) - min(quality, 1.0))
    quality += 0.15 * calibration
    return round(min(quality, 1.0), 4)


def final_diagnosis_score(diagnosis: dict, ground_truth: dict) -> float:
    """Score the committed diagnosis against the ground truth.

    Weights: bug type 0.40, affected file 0.25, line-range overlap 0.20,
    fix strategy 0.15. The result is clamped into (0, 1) and rounded.
    """
    score = 0.0
    if diagnosis.get("bug_type") == ground_truth["bug_type"]:
        score += 0.40
    if diagnosis.get("affected_file") == ground_truth["primary_bug_file"]:
        score += 0.25

    predicted = diagnosis.get("line_range", _NO_RANGE)
    actual = ground_truth.get("line_range", _NO_RANGE)
    # Only award overlap credit when both sides supplied a real range.
    # Previously two missing ranges both defaulted to [0, 0], which overlap
    # perfectly and handed out the full 0.20 for predicting nothing.
    if predicted != _NO_RANGE and actual != _NO_RANGE:
        score += 0.20 * line_overlap(predicted, actual)

    if diagnosis.get("fix_strategy") == ground_truth["fix_strategy"]:
        score += 0.15
    # clamp_score's bounds are tighter than [0, 1], so no extra min is needed.
    return round(clamp_score(score), 4)


def line_overlap(pred: list[int], actual: list[int]) -> float:
    """Return the overlap ratio (intersection / hull) of two inclusive ranges.

    Endpoints are normalized, so a reversed range like [10, 5] is treated
    as [5, 10]. Disjoint ranges score 0.0 (the intersection is empty, and
    for intervals the hull equals the union whenever they touch).
    """
    p1, p2 = sorted(pred)
    a1, a2 = sorted(actual)
    inter = max(0, min(p2, a2) - max(p1, a1) + 1)
    union = max(p2, a2) - min(p1, a1) + 1
    return inter / union if union > 0 else 0.0


def compute_step_reward(
    previous_quality: float,
    current_hypothesis: dict,
    ground_truth: dict,
    investigation_target: str | None = None,
    committed_diagnosis: dict | None = None,
    step_num: int = 1,
    max_steps: int = 5,
) -> tuple[float, dict]:
    """Compute step-level reward and diagnostic components.

    Returns ``(total, components)`` where ``total`` is clamped into (0, 1)
    and ``components`` is a per-term breakdown for logging/debugging.
    """
    current_quality = hypothesis_quality(current_hypothesis, ground_truth)
    delta = current_quality - previous_quality
    # Small bonus for holding a stable hypothesis, so the agent is not
    # incentivized to churn hypotheses purely to harvest delta reward.
    confirmation_bonus = 0.03 * current_quality if abs(delta) < 0.01 else 0.0

    investigation_reward = 0.0
    if investigation_target:
        if investigation_target == ground_truth["primary_bug_file"]:
            investigation_reward = 0.07
        elif investigation_target in ground_truth.get("related_files", []):
            investigation_reward = 0.025
        elif investigation_target == ground_truth.get("red_herring_file"):
            investigation_reward = -0.04  # chasing the decoy costs extra
        else:
            investigation_reward = -0.01  # mild cost for irrelevant files

    diagnosis_reward = 0.0
    if committed_diagnosis:
        diagnosis_reward = final_diagnosis_score(committed_diagnosis, ground_truth)
        # Speed bonus: a mostly-correct diagnosis committed early earns up
        # to 0.08 per remaining step (can push diagnosis_reward above 1.0;
        # the final clamp bounds the total).
        if diagnosis_reward > 0.7:
            diagnosis_reward += max(0.0, 0.08 * (max_steps - step_num))

    total = (
        0.60 * delta
        + 0.20 * investigation_reward
        + 0.20 * diagnosis_reward
        + confirmation_bonus
    )
    # clamp_score's bounds (EPSILON, 1 - EPSILON) are strictly tighter than
    # [0, 1], so the previous inner min/max was redundant.
    total = round(clamp_score(total), 4)
    return total, {
        "hypothesis_quality": current_quality,
        "hypothesis_delta": round(delta, 4),
        "investigation_reward": round(investigation_reward, 4),
        "diagnosis_reward": round(diagnosis_reward, 4),
        "confirmation_bonus": round(confirmation_bonus, 4),
    }