| """ShareGPT + POLAR reward environment.""" |
|
|
| from __future__ import annotations |
|
|
| from pathlib import Path |
| from typing import Any |
|
|
| from datasets import Dataset, load_dataset |
| import asyncio |
|
|
| import verifiers as vf |
| from verifiers.types import Messages |
| from xtuner.utils import RewardModelClient |
|
|
| DEFAULT_MODEL = "internlm/POLAR-7B" |
|
|
|
|
def _load_sharegpt_dataset(path: str | Path) -> Dataset:
    """Load a ShareGPT-style JSON file and convert it to single-turn prompts.

    Each ShareGPT record carries a ``conversations`` list of
    ``{"from": ..., "value": ...}`` turns.  Only the first ``human`` and first
    ``gpt`` turn are kept: the human turn becomes the user prompt and the gpt
    turn is stored under ``info["reference"]`` for the reward model to compare
    completions against.

    Args:
        path: Path to a JSON/JSONL file readable by
            ``datasets.load_dataset("json", ...)``.

    Returns:
        A ``datasets.Dataset`` with ``prompt`` and ``info`` columns; all
        original columns are dropped.

    Raises:
        ValueError: If a record lacks a ``human`` or a ``gpt`` turn.
    """
    dataset = load_dataset("json", data_files=str(path), split="train")

    def to_single_turn(example: dict[str, Any]) -> dict[str, Any]:
        # Use a default of None so a malformed record raises a clear
        # ValueError instead of a bare StopIteration from next().
        human_turn = next(
            (
                turn["value"]
                for turn in example["conversations"]
                if turn["from"] == "human"
            ),
            None,
        )
        assistant_turn = next(
            (
                turn["value"]
                for turn in example["conversations"]
                if turn["from"] == "gpt"
            ),
            None,
        )
        if human_turn is None or assistant_turn is None:
            raise ValueError(
                "ShareGPT record must contain both a 'human' and a 'gpt' turn"
            )
        return {
            "prompt": [{"role": "user", "content": human_turn}],
            "info": {
                "reference": [{"role": "assistant", "content": assistant_turn}],
            },
        }

    return dataset.map(to_single_turn, remove_columns=dataset.column_names)
|
|
|
|
class PoolingClient:
    """Thin wrapper around xtuner's ``RewardModelClient`` for POLAR scoring.

    Encodes (prompt, reference, output) triples into POLAR's pairwise reward
    format and requests scalar rewards from a running reward-model server.
    """

    def __init__(
        self,
        model_path: str,
        server_address: str,
        server_type: str = "lmdeploy",
        max_length: int = 16384,
        max_response_length: int = 4096,
        response_cut_side: str = "left",
    ):
        """Create a client bound to one reward-model server.

        Args:
            model_path: Path/name of the reward model (used for tokenization
                limits inside ``RewardModelClient``).
            server_address: Address of the running reward server.
            server_type: Serving backend ("lmdeploy", "vllm", "sglang", ...).
            max_length: Maximum total encoded length accepted by the client.
            max_response_length: Maximum response length before truncation.
            response_cut_side: Which side of an over-long response to cut.
        """
        # Kept so score() can dispatch to the matching request method.
        self.server_type = server_type
        self.client = RewardModelClient(
            model_path,
            max_length=max_length,
            max_response_length=max_response_length,
            response_cut_side=response_cut_side,
            server_type=server_type,
            server_address=server_address,
        )

    def encode(self, sample: dict[str, Any]) -> str:
        """Render one sample into POLAR's reward-comparison prompt format.

        The format places the reference trajectory before ``<|reward|>`` and
        the candidate trajectory after it, terminated by the POLAR end token.
        Missing keys in *sample* are treated as empty message lists.
        """
        prompt_text = "\n".join(
            message["content"] for message in sample.get("prompt", [])
        )
        reference_text = "\n".join(
            message["content"] for message in sample.get("reference", [])
        )
        output_text = "\n".join(
            message["content"] for message in sample.get("output", [])
        )
        return f"{prompt_text}\n{reference_text}<|reward|>{prompt_text}\n{output_text}[UNUSED_TOKEN_130]"

    def score(self, payload: list[dict[str, Any]]) -> list[float]:
        """Encode and score a batch of samples, returning one reward each.

        Raises:
            RuntimeError: If the server returns no rewards.
        """
        encoded_payload = [self.encode(item) for item in payload]
        # Dispatch on the configured backend instead of hard-coding lmdeploy,
        # so the ``server_type`` constructor argument actually takes effect.
        # NOTE(review): assumes RewardModelClient exposes
        # ``{server_type}_request_reward`` methods (true for lmdeploy/vllm/
        # sglang in xtuner) — confirm against the installed version.
        request_reward = getattr(self.client, f"{self.server_type}_request_reward")
        rewards = request_reward(encoded_payload)
        if rewards is None:
            raise RuntimeError(
                f"Failed to get rewards from {self.server_type} server"
            )
        return rewards
|
|
|
|
async def polar_reward(
    prompt: Messages,
    completion: Messages,
    info: dict[str, Any],
    reward_client: PoolingClient,
    pooling_semaphore: asyncio.Semaphore,
    **_: Any,
) -> float:
    """Score the last assistant turn of *completion* with the POLAR reward model.

    The reward model compares the candidate output against the reference
    answer stored in ``info["reference"]``.  Raw rewards are scaled by 10.

    Returns:
        0.0 when the completion contains no assistant message, otherwise the
        scaled reward for the final assistant turn.

    Raises:
        RuntimeError: If the reward server returns an empty response.
    """
    last_assistant: dict[str, Any] | None = None
    for message in completion:
        if message.get("role") == "assistant":
            last_assistant = message
    if last_assistant is None:
        return 0.0

    sample = {
        "prompt": prompt,
        "reference": info.get("reference", []),
        "output": [last_assistant],
    }
    async with pooling_semaphore:
        # Run the blocking HTTP request in a worker thread so it does not
        # stall the event loop; the semaphore bounds concurrent requests.
        scores = await asyncio.get_running_loop().run_in_executor(
            None, reward_client.score, [sample]
        )
    if not scores:
        raise RuntimeError(f"Unexpected reward response: {scores}")
    return float(scores[-1]) * 10.0
|
|
|
|
def load_environment(
    data_path: str | Path,
    *,
    server_address: str,
    reward_model: str = DEFAULT_MODEL,
    reward_scheme: type[vf.Rubric] | None = None,
    server_type: str = "lmdeploy",
    **env_kwargs: Any,
) -> vf.SingleTurnEnv:
    """Build a single-turn verifiers environment scored by a POLAR server.

    Args:
        data_path: ShareGPT-style JSON file with ``conversations`` records.
        server_address: Address of the running reward-model server.
        reward_model: Reward model path/name (defaults to POLAR-7B).
        reward_scheme: Optional rubric class to instantiate instead of
            ``vf.Rubric``.
        server_type: Serving backend for the reward model.
        **env_kwargs: Extra keyword arguments forwarded to ``SingleTurnEnv``.

    Returns:
        A configured ``vf.SingleTurnEnv``.
    """
    train_dataset = _load_sharegpt_dataset(data_path)

    scoring_client = PoolingClient(
        model_path=reward_model,
        server_address=server_address,
        server_type=server_type,
    )

    rubric_factory = vf.Rubric if reward_scheme is None else reward_scheme
    rubric = rubric_factory(funcs=[polar_reward])
    # These objects are injected into polar_reward's matching keyword
    # parameters by the rubric machinery.
    rubric.class_objects["reward_client"] = scoring_client
    rubric.class_objects.setdefault("pooling_semaphore", asyncio.Semaphore(4))

    return vf.SingleTurnEnv(dataset=train_dataset, rubric=rubric, **env_kwargs)
|
|
|
|
|
|
|
|