"""ShareGPT + POLAR reward environment.""" from __future__ import annotations import json from http import HTTPStatus from pathlib import Path from typing import Any from datasets import Dataset, load_dataset import httpx import verifiers as vf from verifiers.types import Messages DEFAULT_SERVER = "wealth-intent-submissions-range.trycloudflare.com" DEFAULT_MODEL = "internlm/POLAR-7B" def _load_sharegpt_dataset(path: str | Path) -> Dataset: dataset = load_dataset("json", data_files=str(path), split="train") def to_single_turn(example: dict[str, Any]) -> dict[str, Any]: human_turn = next( turn["value"] for turn in example["conversations"] if turn["from"] == "human" ) assistant_turn = next( turn["value"] for turn in example["conversations"] if turn["from"] == "gpt" ) return { "prompt": [{"role": "user", "content": human_turn}], "info": { "reference": [{"role": "assistant", "content": assistant_turn}], }, } return dataset.map(to_single_turn, remove_columns=dataset.column_names) async def polar_reward( prompt: Messages, completion: Messages, info: dict[str, Any], reward_client: "PolarClient", **_: Any, ) -> float: assistant_turns = [msg for msg in completion if msg.get("role") == "assistant"] if not assistant_turns: return 0.0 payload = [ { "prompt": prompt, "reference": info.get("reference", []), "output": [assistant_turns[-1]], } ] scores = await reward_client.score(payload) return float(scores[0]) if scores else 0.0 def load_environment( data_path: str | Path, *, server_address: str = DEFAULT_SERVER, reward_model: str = DEFAULT_MODEL, reward_scheme: type[vf.Rubric] | None = None, **env_kwargs: Any, ) -> vf.SingleTurnEnv: dataset = _load_sharegpt_dataset(data_path) client = PolarClient( base_url=f"https://{server_address}", model=reward_model, ) rubric_cls = reward_scheme or vf.Rubric rubric = rubric_cls(funcs=[polar_reward]) rubric.class_objects["reward_client"] = client return vf.SingleTurnEnv(dataset=dataset, rubric=rubric, **env_kwargs) class PolarClient: """Minimal async client for POLAR reward model served via vLLM.""" def __init__(self, *, base_url: str, model: str, timeout: float = 30.0, api_key: str | None = None): self.base_url = base_url.rstrip("/") self.model = model self.timeout = timeout self.api_key = api_key async def score(self, payload: list[dict[str, Any]]) -> list[float]: encoded = self._encode(payload) async with httpx.AsyncClient(timeout=self.timeout) as client: response = await client.post( f"{self.base_url}/v1/rewards", json={"model": self.model, "input": encoded}, headers={"Authorization": f"Bearer {self.api_key}"} if self.api_key else None, ) if response.status_code != HTTPStatus.OK: raise RuntimeError( f"POLAR reward request failed: {response.status_code} {response.text}" ) data = response.json() return data.get("rewards", []) @staticmethod def _encode(payload: list[dict[str, Any]]) -> list[dict[str, Any]]: # Ensure payload matches expected schema; keep implementation simple for now. return payload