| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| """ |
| Core functions for implementing PPO-style algorithms. |
| The functions in this file are meant to be shared by trainers that use different distributed strategies |
| to implement PPO. |
| """ |
|
|
| from abc import ABC, abstractmethod |
| from collections import defaultdict |
| from enum import Enum |
| from typing import TYPE_CHECKING, Any, Literal |
|
|
| import numpy as np |
| import torch |
| import torch.nn.functional as F |
|
|
| from ..utils import torch_functional as VF |
|
|
|
|
| if TYPE_CHECKING: |
| from .config import AlgorithmConfig |
|
|
|
|
class KLController(ABC):
    """Interface for objects that own the KL penalty coefficient and adjust it over time."""

    # Current KL penalty coefficient; concrete subclasses assign it in __init__.
    kl_coef: float

    @abstractmethod
    def update(self, current_kl: float, n_steps: int):
        """Adjust ``kl_coef`` given the most recently measured KL and the number of steps taken."""
|
|
|
|
class AdaptiveKLController(KLController):
    """KL controller that nudges the coefficient toward a target KL.

    Scheme described in https://arxiv.org/pdf/1909.08593.pdf.
    Copied from https://github.com/huggingface/trl/blob/v0.11.0/trl/trainer/utils.py#L54

    Args:
        init_kl_coef: starting value of ``kl_coef``.
        target_kl: KL value the controller steers toward.
        horizon: number of steps over which a full (clipped) correction is applied.
    """

    def __init__(self, init_kl_coef: float, target_kl: float, horizon: float):
        self.kl_coef = init_kl_coef
        self.target = target_kl
        self.horizon = horizon

    def update(self, current_kl: float, n_steps: int):
        """Scale ``kl_coef`` multiplicatively by the relative KL error, clipped to +/-20%."""
        relative_gap = current_kl / self.target - 1
        bounded_gap = np.clip(relative_gap, -0.2, 0.2)
        self.kl_coef = self.kl_coef * (1 + bounded_gap * n_steps / self.horizon)
|
|
|
|
class FixedKLController(KLController):
    """KL controller whose coefficient never changes.

    Copied from https://github.com/huggingface/trl/blob/v0.11.0/trl/trainer/utils.py#L72

    Args:
        init_kl_coef: the constant value of ``kl_coef``.
    """

    def __init__(self, init_kl_coef: float):
        self.kl_coef = init_kl_coef

    def update(self, current_kl: float, n_steps: int):
        """Intentionally a no-op: ``kl_coef`` stays at its initial value."""
        return None
|
|
|
|
class AdvantageEstimator(str, Enum):
    """Closed set of advantage-estimator names.

    Because the enum mixes in ``str``, each member compares equal to its raw
    string value, and converting an unknown string raises ``ValueError`` instead
    of silently accepting a misspelled ``adv_estimator``.
    """

    # Generalized Advantage Estimation (uses critic values).
    GAE = "gae"
    # Group-normalized outcome advantage.
    GRPO = "grpo"
    # REINFORCE++ discounted returns with masked whitening.
    REINFORCE_PLUS_PLUS = "reinforce_plus_plus"
    # ReMax: outcome score minus a provided baseline.
    REMAX = "remax"
    # REINFORCE leave-one-out baseline.
    RLOO = "rloo"
    # GRPO variant scaled by a task-level EMA standard deviation.
    EMA_GRPO = "ema_grpo"
|
|
|
|
# Registry of advantage estimators keyed by their string name; populated by
# the ``register_adv_estimator`` decorator and read by ``compute_advantage_return``.
ADV_ESTIMATOR_MAP: dict[str, Any] = dict()
|
|
|
|
def get_kl_controller(algorithm_config: "AlgorithmConfig") -> KLController:
    """Build the KL controller selected by ``algorithm_config.kl_type``.

    Adapted from https://github.com/huggingface/trl/blob/v0.11.0/trl/trainer/ppo_trainer.py#L319

    Args:
        algorithm_config: object exposing ``kl_type`` ("fixed" or "adaptive"),
            ``kl_coef``, and, for the adaptive case, ``kl_target`` and ``kl_horizon``.

    Returns:
        A ``FixedKLController`` or ``AdaptiveKLController``.

    Raises:
        AssertionError: if ``kl_type == "adaptive"`` and ``kl_horizon <= 0``.
        ValueError: if ``kl_type`` is not recognized.
    """
    kl_type = algorithm_config.kl_type

    if kl_type == "fixed":
        return FixedKLController(init_kl_coef=algorithm_config.kl_coef)

    if kl_type == "adaptive":
        horizon = algorithm_config.kl_horizon
        assert horizon > 0, f"horizon must be larger than 0. Got {horizon}."
        return AdaptiveKLController(
            init_kl_coef=algorithm_config.kl_coef,
            target_kl=algorithm_config.kl_target,
            horizon=horizon,
        )

    raise ValueError(f"Unknown kl type: {kl_type}.")
|
|
|
|
def register_adv_estimator(name: AdvantageEstimator):
    """Return a decorator that registers an advantage-estimator function under ``name``.

    ``name`` may be an ``AdvantageEstimator`` member or a plain string; members are
    stored under their ``.value``. The decorated function is wrapped in
    ``torch.no_grad()``, and that wrapped version is both stored in
    ``ADV_ESTIMATOR_MAP`` and returned.
    """
    registry_key = getattr(name, "value", name)

    def decorator(fn):
        no_grad_fn = torch.no_grad()(fn)
        ADV_ESTIMATOR_MAP[registry_key] = no_grad_fn
        return no_grad_fn

    return decorator
|
|
|
|
def compute_advantage_return(name: AdvantageEstimator, **kwargs) -> tuple[torch.Tensor, torch.Tensor]:
    """Dispatch to the advantage estimator registered under ``name``.

    ``name`` may be an ``AdvantageEstimator`` member or its raw string value. All
    keyword arguments are forwarded unchanged to the estimator. Raises ``KeyError``
    if nothing is registered under that name.
    """
    registry_key = getattr(name, "value", name)
    estimator = ADV_ESTIMATOR_MAP[registry_key]
    return estimator(**kwargs)
|
|
|
|
@register_adv_estimator(AdvantageEstimator.GAE)
def compute_gae_advantage_return(
    token_level_rewards: torch.Tensor,
    values: torch.Tensor,
    response_mask: torch.Tensor,
    gamma: torch.Tensor,
    lam: torch.Tensor,
    **kwargs,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Generalized Advantage Estimation (https://arxiv.org/abs/1506.02438).

    Adapted from https://github.com/huggingface/trl/blob/v0.16.0/trl/trainer/ppo_trainer.py#L513

    Args:
        token_level_rewards: `(torch.Tensor)` shape: (bs, response_length)
        values: `(torch.Tensor)` shape: (bs, response_length)
        response_mask: `(torch.Tensor)` shape: (bs, response_length).
            Tokens after the EOS token have mask zero.
        gamma: `(float)` discount factor.
        lam: `(float)` GAE lambda.

    Returns:
        advantages: `(torch.Tensor)` shape: (bs, response_length), passed through
            ``VF.masked_whiten`` with ``response_mask``.
        returns: `(torch.Tensor)` shape: (bs, response_length), equal to the
            un-whitened advantages plus ``values``.
    """
    response_length = token_level_rewards.shape[-1]
    running_gae = 0
    reversed_advantages = []

    for step in range(response_length - 1, -1, -1):
        # Bootstrap from the next token's value; the final token has no successor.
        if step < response_length - 1:
            bootstrap = values[:, step + 1]
        else:
            bootstrap = 0.0
        td_error = token_level_rewards[:, step] + gamma * bootstrap - values[:, step]
        running_gae = td_error + gamma * lam * running_gae
        reversed_advantages.append(running_gae)

    reversed_advantages.reverse()
    advantages = torch.stack(reversed_advantages, dim=1)
    # Returns use the raw advantages; only the advantages themselves are whitened.
    returns = advantages + values
    return VF.masked_whiten(advantages, response_mask), returns
|
|
|
|
| |
@register_adv_estimator(AdvantageEstimator.GRPO)
def compute_grpo_outcome_advantage(
    token_level_rewards: torch.Tensor, response_mask: torch.Tensor, index: torch.Tensor, eps: float = 1e-6, **kwargs
) -> tuple[torch.Tensor, torch.Tensor]:
    """GRPO outcome advantage: each response's summed reward, standardized within its group.

    Samples whose ``index`` entries are equal form one group. Each score becomes
    ``(score - group_mean) / (group_std + eps)``, using ``torch.std``'s default
    (unbiased) estimator.

    Args:
        token_level_rewards: `(torch.Tensor)` shape: (bs, response_length)
        response_mask: `(torch.Tensor)` shape: (bs, response_length)
        index: `(torch.Tensor)` shape: (bs,), group id per sample (used as a dict key).
        eps: `(float)` added to the std to avoid division by zero.

    Returns:
        advantages: `(torch.Tensor)` shape: (bs, response_length)
        returns: `(torch.Tensor)` shape: (bs, response_length), the same tensor as advantages.

    Raises:
        AssertionError: if any group contains a single sample.
    """
    scores = token_level_rewards.sum(dim=-1)
    batch_size = scores.shape[0]

    # Bucket each sample's outcome score under its group id.
    grouped = defaultdict(list)
    for pos in range(batch_size):
        grouped[index[pos]].append(scores[pos])

    # All group statistics are computed before any score is overwritten below.
    stats = {}
    for group_id, members in grouped.items():
        assert len(members) > 1, "GRPO needs rollout.n > 1."
        member_tensor = torch.tensor(members)
        stats[group_id] = (torch.mean(member_tensor), torch.std(member_tensor))

    for pos in range(batch_size):
        group_mean, group_std = stats[index[pos]]
        scores[pos] = (scores[pos] - group_mean) / (group_std + eps)

    advantages = scores.unsqueeze(-1) * response_mask
    return advantages, advantages
|
|
|
|
|
|
|
|
|
|
|
|
| class _EmaStdTracker: |
| """ |
| Maintain EMA first moment E[X] and second moment E[X^2] for each "task key", |
| and compute std = sqrt(E[X^2] - (E[X])^2). |
| - Here X refers to the "Outcome score" (a single scalar obtained by summing token-level rewards over time). |
| - The statistics are persistent only within a single process / rank; in multi-process settings, |
| each process maintains its own tracker (for global consistency, use all-reduce or save to trainer state externally). |
| - Benefit of EMA: adaptive to non-stationary distributions; meanwhile avoids over-reliance on early statistics |
| (the effective memory length is controlled by decay). |
| """ |
| def __init__(self, decay: float = 0.99, min_std: float = 1e-3): |
| self.decay = decay |
| self.min_std = min_std |
| |
| self.state: dict[str, dict[str, float]] = {} |
|
|
| def get_std(self, key: str) -> float: |
| """ |
| Get the current EMA standard deviation of the given task. |
| If uninitialized, return min_std (a conservative value to avoid explosion / NaN at the beginning). |
| """ |
| s = self.state.get(key) |
| if not s or not s.get("initialized", False): |
| return self.min_std |
| |
| var = max(s["m2"] - s["m1"] * s["m1"], 0.0) |
| |
| return float(max(var, self.min_std ** 2) ** 0.5) |
|
|
| def update_with_batch_scores(self, key: str, scores: torch.Tensor): |
| """ |
| Update EMA first and second moments using the "current batch of outcome scores" for this task. |
| Convention: normalization uses the *old EMA values*, and EMA is updated *afterward* using this batch, |
| to avoid bias caused by "seeing itself". |
| """ |
| if scores.numel() == 0: |
| return |
| x = scores.float() |
| |
| |
| m1_batch = x.mean().item() |
| m2_batch = (x * x).mean().item() |
|
|
| s = self.state.get(key) |
| if s is None or not s.get("initialized", False): |
| |
| self.state[key] = {"m1": m1_batch, "m2": m2_batch, "initialized": True} |
| else: |
| d = self.decay |
| s["m1"] = d * s["m1"] + (1.0 - d) * m1_batch |
| s["m2"] = d * s["m2"] + (1.0 - d) * m2_batch |
| s["initialized"] = True |
|
|
| |
| def get_mean(self, key: str) -> float: |
| """ |
| Return the current EMA mean (E[X]) of the task. Return 0.0 if uninitialized. |
| (Used only for logging / monitoring; not involved in algorithmic scaling; no clipping applied.) |
| """ |
| s = self.state.get(key) |
| if not s or not s.get("initialized", False): |
| return 0.0 |
| return float(s["m1"]) |
|
|
| |
| def log_means_and_stds( |
| self, |
| keys: list[str] | None = None, |
| print_fn=None, |
| tag_means: str = "EMA_MEAN", |
| tag_stds: str = "EMA_STD", |
| digits: int = 6, |
| sort_keys: bool = True, |
| ) -> None: |
| """ |
| Print in two lines: |
| [EMA_MEAN] task=mean:... |
| [EMA_STD] task=std:... |
| """ |
| if print_fn is None: |
| print_fn = print |
| if keys is None: |
| keys = list(self.state.keys()) |
| if sort_keys: |
| keys = sorted(keys) |
| if not keys: |
| try: |
| print_fn(f"[{tag_means}] (empty)") |
| print_fn(f"[{tag_stds}] (empty)") |
| except Exception: |
| pass |
| return |
|
|
| mean_line = "[{tag}] ".format(tag=tag_means) + ", ".join( |
| f"{k}=mean:{self.get_mean(k):.{digits}f}" for k in keys |
| ) |
| std_line = "[{tag}] ".format(tag=tag_stds) + ", ".join( |
| f"{k}=std:{self.get_std(k):.{digits}f}" for k in keys |
| ) |
| try: |
| print_fn(mean_line) |
| print_fn(std_line) |
| except Exception: |
| |
| pass |
|
|
|
|
|
|
|
|
|
|
| |
# Module-level tracker shared across calls to the EMA-GRPO estimator; defaults are
# spelled out here, but that estimator overwrites decay/min_std on every call.
_EMA_STD_TRACKER = _EmaStdTracker(decay=0.99, min_std=1e-3)
|
|
|
|
| |
| def _task_key_of(sample_problem_type: str, sample_data_type: str | None) -> str: |
| """ |
| Task partition rule: |
| - By default, aggregate by problem_type; |
| - If problem_type == "segmentation", further split by data_type into |
| "segmentation/image" and "segmentation/video". |
| """ |
| if sample_problem_type == "segmentation": |
| dt = (sample_data_type or "").lower() |
| if dt in ("video", "image"): |
| return f"segmentation/{dt}" |
| return sample_problem_type |
|
|
|
|
|
|
def _ema_grpo_to_list(values):
    """Normalize a per-sample container to a list/tuple; ``None`` passes through unchanged.

    Tensors and NumPy arrays go through ``.tolist()``. Any other object that is not
    already a list or tuple is materialized with ``list(...)``.
    """
    if values is None:
        return None
    if torch.is_tensor(values) or isinstance(values, np.ndarray):
        values = values.tolist()
    if not isinstance(values, (list, tuple)):
        values = list(values)
    return values


@register_adv_estimator(AdvantageEstimator.EMA_GRPO)
def compute_ema_grpo_outcome_advantage(
    token_level_rewards: torch.Tensor,
    response_mask: torch.Tensor,
    index,
    problem_type,
    data_type=None,
    ema_decay: float = 0.99,
    min_std: float = 1e-3,
    use_group_mean_centering: bool = True,
    eps: float = 1e-6,
    guard_abs_max: float = 5.0,
    **kwargs,
) -> tuple[torch.Tensor, torch.Tensor]:
    """
    EMA-GRPO (outcome supervision).

    1) Keep GRPO-style group-wise mean-centering (samples with the same ``index`` are
       centered) to reduce intra-group variance.
    2) Scale by a task-level EMA std instead of the group std.
    3) The EMA is updated with the current batch's raw scores first (initializing
       unseen tasks), and the *updated* std is then used for scaling. This makes the
       statistics responsive to non-stationary score distributions.
    4) Guard rail: if scaling a group by the task std pushes any of its values outside
       [-guard_abs_max, +guard_abs_max], that group alone falls back to its own
       (population) std.

    Args:
        token_level_rewards: expected shape (bs, response_length); summed over the last
            dim to give one outcome score per sample.
        response_mask: expected shape (bs, response_length); moved to the rewards' device
            and cast to the rewards' dtype if it is not floating point.
        index: per-sample group ids (tensor, ndarray, or iterable); compared as ``str``.
        problem_type: per-sample task names, or None, in which case all samples share
            the task key "".
        data_type: optional per-sample data types; only consulted for "segmentation"
            samples (see ``_task_key_of``).
        ema_decay, min_std: written onto the module-level ``_EMA_STD_TRACKER`` on every
            call, so they persist for later calls.
        use_group_mean_centering: subtract each group's mean before scaling.
        eps: added to every std before dividing.
        guard_abs_max: threshold for the per-group fallback described above.

    Side effects:
        Updates ``_EMA_STD_TRACKER`` state and prints its means/stds on every call.

    Returns:
        advantages: (bs, response_length)
        returns: (bs, response_length), the same tensor as advantages.

    Raises:
        AssertionError: if centering is enabled and some group has a single sample.
    """
    device = token_level_rewards.device
    dtype = token_level_rewards.dtype

    # The tracker is shared module state; these hyperparameters persist after this call.
    _EMA_STD_TRACKER.decay = float(ema_decay)
    _EMA_STD_TRACKER.min_std = float(min_std)

    index_keys = [str(g) for g in _ema_grpo_to_list(index)]
    problem_type_list = _ema_grpo_to_list(problem_type)
    data_type_list = _ema_grpo_to_list(data_type)

    scores = token_level_rewards.sum(dim=-1)
    bsz = scores.shape[0]

    # Partition sample positions by task key.
    task_to_pos: dict[str, list[int]] = defaultdict(list)
    for i in range(bsz):
        pt = str(problem_type_list[i]) if problem_type_list is not None else ""
        dt = None
        # data_type only refines segmentation keys; a missing entry falls back to None.
        if pt == "segmentation" and data_type_list is not None and i < len(data_type_list):
            dt = str(data_type_list[i])
        task_to_pos[_task_key_of(pt, dt)].append(i)

    # Step 1: group-wise mean-centering.
    centered = scores.clone()
    if use_group_mean_centering:
        gid_to_pos: dict[str, list[int]] = defaultdict(list)
        for i, gid in enumerate(index_keys):
            gid_to_pos[gid].append(i)
        for pos_list in gid_to_pos.values():
            assert len(pos_list) > 1, "EMA-GRPO requires rollout.n > 1 per group (same index)."
            group_scores = scores[pos_list]
            centered[pos_list] = group_scores - group_scores.mean()

    # Step 2: update each task's EMA with the raw (uncentered) scores, before scaling.
    for key, pos_list in task_to_pos.items():
        _EMA_STD_TRACKER.update_with_batch_scores(key, scores[pos_list])

    _EMA_STD_TRACKER.log_means_and_stds(
        keys=None,
        tag_means="EMA_MEAN(after_update)",
        tag_stds="EMA_STD(after_update)",
    )

    # Step 3: scale each group by its task's updated EMA std, with a per-group guard rail.
    scaled = centered.clone()
    for key, task_pos in task_to_pos.items():
        task_std = _EMA_STD_TRACKER.get_std(key)

        group_to_pos: dict[str, list[int]] = defaultdict(list)
        for i in task_pos:
            group_to_pos[index_keys[i]].append(i)

        for gpos in group_to_pos.values():
            task_scaled = centered[gpos] / (task_std + eps)
            if torch.any(torch.abs(task_scaled) > guard_abs_max):
                # Fall back to this group's own population std (affects only this group).
                g_std = torch.std(scores[gpos].float(), unbiased=False).item()
                scaled[gpos] = centered[gpos] / (g_std + eps)
            else:
                scaled[gpos] = task_scaled

    if response_mask.device != device:
        response_mask = response_mask.to(device)
    if not torch.is_floating_point(response_mask):
        response_mask = response_mask.to(dtype)

    returns = scaled.to(dtype).unsqueeze(-1) * response_mask
    return returns, returns
|
|
|
|
|
|
|
|
|
|
@register_adv_estimator(AdvantageEstimator.RLOO)
def compute_rloo_outcome_advantage(
    token_level_rewards: torch.Tensor, response_mask: torch.Tensor, index: torch.Tensor, **kwargs
) -> tuple[torch.Tensor, torch.Tensor]:
    """RLOO outcome advantage (https://arxiv.org/abs/2402.14740).

    Each response's summed reward has the mean of the *other* responses in its group
    (same ``index``) subtracted from it.

    Args:
        token_level_rewards: `(torch.Tensor)` shape: (bs, response_length)
        response_mask: `(torch.Tensor)` shape: (bs, response_length)
        index: `(torch.Tensor)` shape: (bs,), group id per sample (used as a dict key).

    Returns:
        advantages: `(torch.Tensor)` shape: (bs, response_length)
        returns: `(torch.Tensor)` shape: (bs, response_length), the same tensor as advantages.

    Raises:
        AssertionError: if a sample's group has only one member.
    """
    scores = token_level_rewards.sum(dim=-1)
    batch_size = scores.shape[0]

    members = defaultdict(list)
    for pos in range(batch_size):
        members[index[pos]].append(scores[pos])

    # Group totals are taken before any score is overwritten below.
    totals = {group_id: torch.sum(torch.tensor(vals)) for group_id, vals in members.items()}

    for pos in range(batch_size):
        group_id = index[pos]
        count = len(members[group_id])
        assert count > 1, "RLOO needs rollout.n > 1."
        leave_one_out_mean = (totals[group_id] - scores[pos]) / (count - 1)
        scores[pos] = scores[pos] - leave_one_out_mean

    advantages = scores.unsqueeze(-1) * response_mask
    return advantages, advantages
|
|
|
|
@register_adv_estimator(AdvantageEstimator.REINFORCE_PLUS_PLUS)
def compute_reinforce_plus_plus_outcome_advantage(
    token_level_rewards: torch.Tensor, response_mask: torch.Tensor, gamma: torch.Tensor, **kwargs
) -> tuple[torch.Tensor, torch.Tensor]:
    """REINFORCE++ advantage (https://arxiv.org/abs/2501.03262).

    Computes per-token discounted returns right-to-left, then whitens them with
    ``VF.masked_whiten`` to obtain advantages.

    Args:
        token_level_rewards: `(torch.Tensor)` shape: (bs, response_length)
        response_mask: `(torch.Tensor)` shape: (bs, response_length)
        gamma: discount factor.

    Returns:
        advantages: `(torch.Tensor)` shape: (bs, response_length)
        returns: `(torch.Tensor)` shape: (bs, response_length), un-whitened discounted returns.
    """
    returns = torch.zeros_like(token_level_rewards)
    discounted = 0
    num_steps = token_level_rewards.shape[1]

    for step in range(num_steps - 1, -1, -1):
        discounted = token_level_rewards[:, step] + gamma * discounted
        returns[:, step] = discounted
        # A masked-out position contributes no carry to earlier steps.
        discounted = discounted * response_mask[:, step]

    advantages = VF.masked_whiten(returns, response_mask)
    return advantages, returns
|
|
|
|
@register_adv_estimator(AdvantageEstimator.REMAX)
def compute_remax_outcome_advantage(
    token_level_rewards: torch.Tensor, reward_baselines: torch.Tensor, response_mask: torch.Tensor, **kwargs
) -> tuple[torch.Tensor, torch.Tensor]:
    """ReMax outcome advantage (https://arxiv.org/abs/2310.10505).

    Each response gets one scalar, its summed reward minus its baseline, broadcast
    over the response tokens and masked.

    Args:
        token_level_rewards: `(torch.Tensor)` shape: (bs, response_length)
        reward_baselines: `(torch.Tensor)` shape: (bs,)
        response_mask: `(torch.Tensor)` shape: (bs, response_length)

    Returns:
        advantages: `(torch.Tensor)` shape: (bs, response_length)
        returns: `(torch.Tensor)` shape: (bs, response_length), the same tensor as advantages.
    """
    baseline_adjusted = token_level_rewards.sum(dim=-1) - reward_baselines
    advantages = baseline_adjusted.unsqueeze(-1) * response_mask
    return advantages, advantages
|
|
|
|
def compute_rewards(
    token_level_scores: torch.Tensor,
    log_probs: torch.Tensor,
    ref_log_probs: torch.Tensor,
    kl_ratio: float,
) -> torch.Tensor:
    """Apply a per-token KL penalty: ``scores - kl_ratio * (log_probs - ref_log_probs)``."""
    per_token_kl = log_probs - ref_log_probs
    penalty = per_token_kl * kl_ratio
    return token_level_scores - penalty
|
|
|
|
| def average_loss( |
| values: torch.Tensor, mask: torch.Tensor, mode: Literal["token", "seq"], eps: float = 1e-8 |
| ) -> torch.Tensor: |
| """Average the policy loss. |
| |
| Args: |
| values: `(torch.Tensor)` |
| shape: (bs, response_length) |
| mask: `(torch.Tensor)` |
| shape: (bs, response_length) |
| mode: `(Literal["token", "seq"])` |
| "token": average the loss in the whole batch |
| "seq": average the loss in each sequence then average the mean of the means |
| eps: `(float)` |
| epsilon value |
| |
| Returns: |
| loss: `a scalar torch.Tensor` |
| """ |
| if mode == "token": |
| return VF.masked_mean(values, mask, eps=eps) |
| elif mode == "seq": |
| return ((values * mask).sum(-1) / (mask.sum(-1) + eps)).mean() |
| else: |
| raise NotImplementedError(f"Unknown mode: {mode}.") |
|
|
|
|
def compute_policy_loss(
    old_log_probs: torch.Tensor,
    log_probs: torch.Tensor,
    advantages: torch.Tensor,
    response_mask: torch.Tensor,
    clip_ratio_low: float,
    clip_ratio_high: float,
    clip_ratio_dual: float,
    loss_avg_mode: Literal["token", "seq"],
) -> tuple[torch.Tensor, dict[str, float]]:
    """Compute the clipped PPO policy objective and related metrics.

    Uses asymmetric clipping of the importance ratio (lower bound ``1 - clip_ratio_low``,
    upper bound ``1 + clip_ratio_high``). Tokens with negative advantage also get
    dual-clip, which caps their loss at ``-advantages * clip_ratio_dual``.

    Adapted from https://github.com/huggingface/trl/blob/v0.15.0/trl/trainer/ppo_trainer.py#L568

    Args:
        old_log_probs: `(torch.Tensor)`
            shape: (bs, response_length), log-probs under the behavior (old) policy.
        log_probs: `(torch.Tensor)`
            shape: (bs, response_length), log-probs under the current policy.
        advantages: `(torch.Tensor)`
            shape: (bs, response_length)
        response_mask: `(torch.Tensor)`
            shape: (bs, response_length)
        clip_ratio_low: (float)
            The lower clip range used in PPO. See https://arxiv.org/abs/1707.06347
        clip_ratio_high: (float)
            The higher clip range used in DAPO. See https://arxiv.org/pdf/2503.14476
        clip_ratio_dual: (float)
            The dual clip range used in Dual-clip PPO. See https://arxiv.org/pdf/1912.09729
        loss_avg_mode: (Literal["token", "seq"])
            "token": average the loss over the whole batch
            "seq": average the loss in each sequence, then average the per-sequence means

    Returns:
        pg_loss: a scalar torch.Tensor, the policy-gradient loss reduced with ``loss_avg_mode``.
        metrics: dict of Python floats, each reduced with ``VF.masked_mean`` over ``response_mask``:
            "ppo_kl": mean of ``old_log_probs - log_probs`` (log-ratio clamped to [-20, 20]).
            "entropy_loss": ``average_loss(-log_probs)``, a scalar that is then passed
                through ``VF.masked_mean`` again.
            "pg_clipfrac_higher": fraction of tokens where the clipped-ratio term wins the max.
            "pg_clipfrac_lower": fraction of negative-advantage tokens where dual-clip is active.
    """
    negative_approx_kl = log_probs - old_log_probs
    # Clamp the log-ratio so exp() cannot overflow.
    negative_approx_kl = torch.clamp(negative_approx_kl, -20.0, 20.0)
    ratio = torch.exp(negative_approx_kl)
    # Clamping in log space is equivalent to clamping the ratio to [1 - low, 1 + high].
    clipped_ratio = torch.exp(
        torch.clamp(negative_approx_kl, np.log(1.0 - clip_ratio_low), np.log(1.0 + clip_ratio_high))
    )

    # Per-token metric tensors; reduced to floats at the end.
    metrics = {"ppo_kl": -negative_approx_kl}
    metrics["entropy_loss"] = average_loss(-log_probs, response_mask, mode=loss_avg_mode)

    pg_loss = -advantages * ratio
    pg_loss2 = -advantages * clipped_ratio
    pg_loss3 = -advantages * clip_ratio_dual

    clipped_pg_loss_higher = torch.max(pg_loss, pg_loss2)
    metrics["pg_clipfrac_higher"] = (pg_loss < pg_loss2).float()
    clipped_pg_loss_lower = torch.min(clipped_pg_loss_higher, pg_loss3)
    # Dual-clip only applies where the advantage is negative.
    negative_adv = advantages < 0
    final_pg_loss = torch.where(negative_adv, clipped_pg_loss_lower, clipped_pg_loss_higher)
    metrics["pg_clipfrac_lower"] = (clipped_pg_loss_higher > pg_loss3).float() * negative_adv.float()

    final_pg_loss = average_loss(final_pg_loss, response_mask, mode=loss_avg_mode)
    metrics = {k: VF.masked_mean(v, response_mask).detach().item() for k, v in metrics.items()}
    return final_pg_loss, metrics
|
|
|
|
def compute_value_loss(
    vpreds: torch.Tensor,
    returns: torch.Tensor,
    values: torch.Tensor,
    response_mask: torch.Tensor,
    cliprange_value: float,
    loss_avg_mode: Literal["token", "seq"],
) -> tuple[torch.Tensor, dict[str, float]]:
    """Clipped value-function loss, as in PPO (https://arxiv.org/abs/1707.06347).

    Adapted from https://github.com/huggingface/trl/blob/v0.15.0/trl/trainer/ppo_trainer.py#L556

    Args:
        vpreds: predicted values, shape (batch_size, response_length).
        returns: target returns, shape (batch_size, response_length).
        values: old value predictions, shape (batch_size, response_length).
        response_mask: shape (batch_size, response_length).
        cliprange_value: predictions are clamped to ``values +/- cliprange_value``.
        loss_avg_mode: "token" or "seq"; forwarded to ``average_loss``.

    Returns:
        vf_loss: scalar tensor, ``0.5 *`` the averaged element-wise max of the
            unclipped and clipped squared errors.
        metrics: {"vf_clipfrac": fraction of tokens where the clipped error is larger,
                  "vpred_mean": masked mean of ``vpreds``}.
    """
    clipped_preds = torch.clamp(vpreds, values - cliprange_value, values + cliprange_value)
    unclipped_err = torch.square(vpreds - returns)
    clipped_err = torch.square(clipped_preds - returns)
    # Pessimistic choice: the larger of the two errors per token.
    per_token_loss = torch.maximum(unclipped_err, clipped_err)
    vf_loss = 0.5 * average_loss(per_token_loss, response_mask, mode=loss_avg_mode)

    clip_fraction = VF.masked_mean((unclipped_err < clipped_err).float(), response_mask)
    mean_prediction = VF.masked_mean(vpreds, response_mask)
    metrics = {
        "vf_clipfrac": clip_fraction.detach().item(),
        "vpred_mean": mean_prediction.detach().item(),
    }
    return vf_loss, metrics
|
|
|
|
| def compute_kl( |
| log_probs: torch.FloatTensor, |
| ref_log_probs: torch.FloatTensor, |
| kl_penalty: Literal["kl", "abs", "mse", "low_var_kl", "full"], |
| ) -> torch.Tensor: |
| """Compute KL divergence given log_probs and ref_log_probs. |
| |
| Adapted from https://github.com/huggingface/trl/blob/v0.11.0/trl/trainer/ppo_trainer.py#L1150 |
| |
| Args: |
| log_probs: torch.Tensor |
| ref_log_probs: torch.Tensor |
| kl_penalty: str ("kl", "abs", "mse", "low_var_kl", "full") |
| |
| Returns: |
| kl_div: torch.Tensor |
| |
| """ |
| log_probs, ref_log_probs = log_probs.float(), ref_log_probs.float() |
| if kl_penalty == "kl": |
| return log_probs - ref_log_probs |
|
|
| if kl_penalty == "abs": |
| return (log_probs - ref_log_probs).abs() |
|
|
| if kl_penalty == "mse": |
| return 0.5 * (log_probs - ref_log_probs).square() |
|
|
| |
| |
| if kl_penalty == "low_var_kl": |
| |
| kl = (ref_log_probs - log_probs).clamp(-20.0, 20.0) |
| kld = (kl.exp() - kl - 1).contiguous() |
| return torch.clamp(kld, min=-10.0, max=10.0) |
|
|
| if kl_penalty == "full": |
| return F.kl_div(ref_log_probs, log_probs, log_target=True, reduction="none").sum(-1) |
|
|
| raise NotImplementedError(f"Unknown KL penalty: {kl_penalty}.") |
|
|