# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
"""LLM-as-a-judge rubric for reward computation.

Uses an LLM endpoint (via LLMClient) to evaluate agent actions/observations.

Usage:
    client = OpenAIClient("http://localhost", 8000, model="meta-llama/...")
    judge = LLMJudge(
        prompt_template="Rate this code solution:\\n{action}\\n\\nScore (0-1):",
        client=client,
    )
    score = await judge(action, observation)

See RFC 004 for full design: rfcs/004-rubrics.md
"""
| import re | |
| from typing import Any, Dict | |
| from openenv.core.llm_client import LLMClient | |
| from openenv.core.rubrics.base import Rubric | |
class LLMJudge(Rubric):
    """Rubric that uses an LLM to evaluate agent actions/observations.

    The prompt template is formatted with ``{action}`` and ``{observation}``
    placeholders. The LLM response is parsed for a numeric score.

    Args:
        prompt_template: Template string with {action} and {observation}
            placeholders.
        client: An LLMClient instance for making LLM calls.
        score_pattern: Regex to extract the score from the LLM response.
            Defaults to matching the first decimal number, including
            leading-dot forms such as ".85".
        default_score: Score returned when parsing fails.
        normalize: If True, clamp extracted score to [0, 1].
    """

    # Matches "7", "7.", "0.85", and leading-dot decimals like ".85".
    # The previous default, r"(\d+\.?\d*)", skipped a leading dot and parsed
    # ".85" as 85 -- which normalize then clamped to 1.0, a silently wrong
    # score. Every response the old default parsed correctly still parses
    # to the same value with this pattern.
    _DEFAULT_SCORE_PATTERN = r"(\d+(?:\.\d*)?|\.\d+)"

    def __init__(
        self,
        prompt_template: str,
        client: LLMClient,
        *,
        score_pattern: str | None = None,
        default_score: float = 0.0,
        normalize: bool = True,
    ):
        super().__init__()
        self.prompt_template = prompt_template
        self._client = client
        # An explicit pattern overrides the default; compile once up front.
        self._score_pattern = re.compile(score_pattern or self._DEFAULT_SCORE_PATTERN)
        self.default_score = default_score
        self.normalize = normalize

    async def forward(self, action: Any, observation: Any) -> float:
        """Evaluate by sending a prompt to the LLM and parsing the score.

        Args:
            action: The action taken by the agent.
            observation: The resulting observation.

        Returns:
            Parsed score from the LLM response.
        """
        prompt = self._render_prompt(action, observation)
        response = await self._client.complete(prompt)
        return self._parse_score(response)

    def _render_prompt(self, action: Any, observation: Any) -> str:
        """Format the prompt template with action and observation.

        Override in subclasses for custom prompt construction.
        """
        return self.prompt_template.format(action=action, observation=observation)

    def _parse_score(self, response: str) -> float:
        """Extract a numeric score from the LLM response.

        Uses the configured regex pattern to find the first match.
        Returns default_score if no match is found or the matched text
        is not a valid float.
        """
        match = self._score_pattern.search(response)
        if match is None:
            return self.default_score
        try:
            # Use first capture group if present, otherwise full match.
            text = match.group(1) if match.lastindex else match.group(0)
            score = float(text)
        except (ValueError, IndexError):
            return self.default_score
        if self.normalize:
            score = max(0.0, min(1.0, score))
        return score

    def state_dict(self) -> Dict[str, Any]:
        """Serialize rubric configuration (the LLM client is not serialized)."""
        return {
            "prompt_template": self.prompt_template,
            "score_pattern": self._score_pattern.pattern,
            "default_score": self.default_score,
            "normalize": self.normalize,
        }

    def load_state_dict(self, state: Dict[str, Any]) -> None:
        """Load rubric configuration from checkpoint.

        Missing keys leave the corresponding attribute untouched, so
        partial checkpoints are accepted.
        """
        if "prompt_template" in state:
            self.prompt_template = state["prompt_template"]
        if "score_pattern" in state:
            self._score_pattern = re.compile(state["score_pattern"])
        if "default_score" in state:
            self.default_score = state["default_score"]
        if "normalize" in state:
            self.normalize = state["normalize"]