| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| """ |
| Implement Actor |
| """ |
|
|
| import os |
| from collections import defaultdict |
| from typing import Any, Optional |
|
|
| import torch |
| import torch.distributed as dist |
| from einops import rearrange |
| from ray.experimental.tqdm_ray import tqdm |
| from torch import nn |
| from torch.distributed.fsdp import FullyShardedDataParallel as FSDP |
|
|
| from ...protocol import DataProto, batch_collate |
| from ...trainer.core_algos import average_loss, compute_kl, compute_policy_loss |
| from ...utils import torch_functional as VF |
| from ...utils.py_functional import append_to_dict |
| from ...utils.seqlen_balancing import prepare_dynamic_batch, restore_dynamic_batch |
| from ...utils.ulysses import gather_outputs_and_unpad, ulysses_pad_and_slice_inputs |
| from .base import BasePPOActor |
| from .config import ActorConfig |
|
|
|
|
| try: |
| from flash_attn.bert_padding import index_first_axis, pad_input, rearrange, unpad_input |
| except ImportError: |
| pass |
|
|
|
|
| __all__ = ["DataParallelPPOActor"] |
|
|
|
|
| class DataParallelPPOActor(BasePPOActor): |
| def __init__( |
| self, |
| config: ActorConfig, |
| actor_module: nn.Module, |
| actor_optimizer: Optional[torch.optim.Optimizer] = None, |
| ): |
| """ |
| When optimizer is None, it is Reference Policy |
| """ |
| super().__init__(config) |
| self.rank = int(os.getenv("RANK", "0")) |
| self.world_size = int(os.getenv("WORLD_SIZE", "1")) |
| self.actor_module = actor_module |
| self.actor_optimizer = actor_optimizer |
| if config.use_torch_compile: |
| self.log_probs_from_logits = torch.compile(VF.log_probs_from_logits, dynamic=True) |
| else: |
| self.log_probs_from_logits = VF.log_probs_from_logits |
|
|
| def _forward_micro_batch(self, micro_batch: dict[str, torch.Tensor], temperature: float) -> torch.Tensor: |
| """ |
| Returns: |
| log_probs: # (bs, response_len) |
| """ |
| input_ids = micro_batch["input_ids"] |
| batch_size, seqlen = input_ids.shape |
| attention_mask = micro_batch["attention_mask"] |
| position_ids = micro_batch["position_ids"] |
| responses = micro_batch["responses"] |
| response_length = responses.size(-1) |
| if position_ids.dim() == 3: |
| position_ids = position_ids.transpose(0, 1) |
|
|
| multi_modal_inputs = defaultdict(list) |
| if "multi_modal_inputs" in micro_batch: |
| multi_modal_inputs = batch_collate(micro_batch["multi_modal_inputs"]) |
| multi_modal_inputs = {key: torch.cat(value, dim=0) for key, value in multi_modal_inputs.items()} |
| else: |
| multi_modal_inputs = {} |
|
|
| if self.config.padding_free: |
| input_ids_rmpad, indices, *_ = unpad_input(input_ids.unsqueeze(-1), attention_mask) |
| input_ids_rmpad = input_ids_rmpad.transpose(0, 1) |
|
|
| |
| if position_ids.dim() == 3: |
| position_ids_rmpad = ( |
| index_first_axis(rearrange(position_ids, "c b s ... -> (b s) c ..."), indices) |
| .transpose(0, 1) |
| .unsqueeze(1) |
| ) |
| else: |
| position_ids_rmpad = index_first_axis( |
| rearrange(position_ids.unsqueeze(-1), "b s ... -> (b s) ..."), indices |
| ).transpose(0, 1) |
|
|
| |
| input_ids_rmpad_rolled = torch.roll(input_ids_rmpad, shifts=-1, dims=1) |
|
|
| |
| if self.config.ulysses_size > 1: |
| input_ids_rmpad, position_ids_rmpad, pad_size = ulysses_pad_and_slice_inputs( |
| input_ids_rmpad, position_ids_rmpad, sp_size=self.config.ulysses_size |
| ) |
| input_ids_rmpad_rolled, _, _ = ulysses_pad_and_slice_inputs( |
| input_ids_rmpad_rolled, None, self.config.ulysses_size |
| ) |
|
|
| input_ids_rmpad_rolled = input_ids_rmpad_rolled.squeeze(0) |
|
|
| |
| output = self.actor_module( |
| input_ids=input_ids_rmpad, |
| attention_mask=None, |
| position_ids=position_ids_rmpad, |
| **multi_modal_inputs, |
| use_cache=False, |
| ) |
| logits_rmpad = output.logits.squeeze(0) |
| logits_rmpad.div_(temperature) |
| |
| log_probs = self.log_probs_from_logits(logits=logits_rmpad, labels=input_ids_rmpad_rolled) |
|
|
| |
| if self.config.ulysses_size > 1: |
| |
| log_probs = gather_outputs_and_unpad(log_probs, gather_dim=0, unpad_dim=0, padding_size=pad_size) |
|
|
| |
| full_log_probs = pad_input( |
| hidden_states=log_probs.unsqueeze(-1), indices=indices, batch=batch_size, seqlen=seqlen |
| ) |
| log_probs = full_log_probs.squeeze(-1)[:, -response_length - 1 : -1] |
| else: |
| output = self.actor_module( |
| input_ids=input_ids, |
| attention_mask=attention_mask, |
| position_ids=position_ids, |
| **multi_modal_inputs, |
| use_cache=False, |
| ) |
| logits: torch.Tensor = output.logits |
| logits.div_(temperature) |
| logits = logits[:, -response_length - 1 : -1, :] |
| log_probs = self.log_probs_from_logits(logits, responses) |
|
|
| return log_probs |
|
|
| def _optimizer_step(self) -> torch.Tensor: |
| if isinstance(self.actor_module, FSDP): |
| grad_norm = self.actor_module.clip_grad_norm_(self.config.max_grad_norm) |
| else: |
| grad_norm = nn.utils.clip_grad_norm_(self.actor_module.parameters(), max_norm=self.config.max_grad_norm) |
|
|
| if not torch.isfinite(grad_norm): |
| print("Gradient norm is not finite. Skip update.") |
| else: |
| self.actor_optimizer.step() |
|
|
| self.actor_optimizer.zero_grad() |
| return grad_norm |
|
|
| @torch.no_grad() |
| def compute_log_prob(self, data: DataProto) -> torch.Tensor: |
| """Compute the log probability of the responses given input_ids, attention_mask and position_ids |
| |
| Args: |
| data (DataProto): a DataProto containing keys |
| |
| ``input_ids``: tensor of shape [batch_size, sequence_length]. torch.int64. Note that input_ids is the |
| concatenation of prompt and response. Note that ``sequence_length = prompt_length + response_length``. |
| |
| ``attention_mask``: tensor of shape [batch_size, sequence_length]. torch.int64. |
| |
| ``position_ids``: tensor of shape [batch_size, sequence_length]. torch.int64. |
| |
| ``responses``: tensor of shape [batch_size, response_length]. torch.int64. |
| |
| Returns: |
| torch.Tensor: the log_prob tensor |
| """ |
| self.actor_module.eval() |
|
|
| temperature = data.meta_info["temperature"] |
| select_keys = ["input_ids", "attention_mask", "position_ids", "responses"] |
| non_tensor_select_keys = ["multi_modal_inputs"] |
|
|
| data = data.select(select_keys, non_tensor_select_keys) |
| if self.config.dynamic_batching: |
| max_token_len = self.config.micro_batch_size_per_device_for_experience * data.batch["input_ids"].size(-1) |
| micro_batches, batch_idx_list = prepare_dynamic_batch(data, max_token_len=max_token_len) |
| else: |
| micro_batches = data.split(self.config.micro_batch_size_per_device_for_experience) |
|
|
| log_probs_lst = [] |
| if self.rank == 0: |
| micro_batches = tqdm(micro_batches, desc="Compute log probs", position=1) |
|
|
| for micro_batch in micro_batches: |
| model_inputs = {**micro_batch.batch, **micro_batch.non_tensor_batch} |
| log_probs = self._forward_micro_batch(model_inputs, temperature=temperature) |
| log_probs_lst.append(log_probs) |
|
|
| log_probs = torch.concat(log_probs_lst, dim=0) |
|
|
| if self.config.dynamic_batching: |
| log_probs = restore_dynamic_batch(log_probs, batch_idx_list) |
|
|
| return log_probs |
|
|
| def update_policy(self, data: DataProto) -> dict[str, Any]: |
| self.actor_module.train() |
|
|
| temperature = data.meta_info["temperature"] |
| select_keys = ["input_ids", "attention_mask", "position_ids", "responses", "response_mask"] |
| select_keys.extend(["old_log_probs", "ref_log_probs", "advantages"]) |
| non_tensor_select_keys = ["multi_modal_inputs"] |
|
|
| |
| |
| mini_batches = data.select(select_keys, non_tensor_select_keys).split(self.config.global_batch_size_per_device) |
|
|
| metrics = defaultdict(list) |
| for _ in range(self.config.ppo_epochs): |
| if self.rank == 0: |
| mini_batches = tqdm(mini_batches, desc="Train mini-batches", position=1) |
|
|
| for mini_batch in mini_batches: |
| total_response_tokens = torch.sum(mini_batch.batch["response_mask"]) |
| dist.all_reduce(total_response_tokens, op=dist.ReduceOp.SUM) |
|
|
| if self.config.dynamic_batching: |
| max_input_len = mini_batch.batch["input_ids"].size(-1) |
| max_token_len = self.config.micro_batch_size_per_device_for_update * max_input_len |
| micro_batches, _ = prepare_dynamic_batch(mini_batch, max_token_len=max_token_len) |
| else: |
| micro_batches = mini_batch.split(self.config.micro_batch_size_per_device_for_update) |
|
|
| if self.rank == 0: |
| micro_batches = tqdm(micro_batches, desc="Update policy", position=2) |
|
|
| for micro_batch in micro_batches: |
| model_inputs = {**micro_batch.batch, **micro_batch.non_tensor_batch} |
| response_mask = model_inputs["response_mask"] |
| old_log_probs = model_inputs["old_log_probs"] |
| advantages = model_inputs["advantages"] |
|
|
| |
| log_probs = self._forward_micro_batch(model_inputs, temperature=temperature) |
|
|
| pg_loss, pg_metrics = compute_policy_loss( |
| old_log_probs=old_log_probs, |
| log_probs=log_probs, |
| advantages=advantages, |
| response_mask=response_mask, |
| clip_ratio_low=self.config.clip_ratio_low, |
| clip_ratio_high=self.config.clip_ratio_high, |
| clip_ratio_dual=self.config.clip_ratio_dual, |
| loss_avg_mode=self.config.loss_avg_mode, |
| ) |
| if self.config.use_kl_loss and "ref_log_probs" in model_inputs: |
| ref_log_probs = model_inputs["ref_log_probs"] |
| |
| kld = compute_kl( |
| log_probs=log_probs, |
| ref_log_probs=ref_log_probs, |
| kl_penalty=self.config.kl_penalty, |
| ) |
| kl_loss = average_loss(kld, response_mask, mode=self.config.loss_avg_mode) |
| loss = pg_loss + kl_loss * self.config.kl_coef |
| metrics["actor/kl_loss"] = kl_loss.detach().item() |
| metrics["actor/kl_coef"] = self.config.kl_coef |
| else: |
| loss = pg_loss |
|
|
| loss = loss * torch.sum(response_mask) * self.world_size / total_response_tokens |
| loss.backward() |
|
|
| batch_metrics = { |
| "actor/pg_loss": pg_loss.detach().item(), |
| "actor/pg_clipfrac_higher": pg_metrics["pg_clipfrac_higher"], |
| "actor/pg_clipfrac_lower": pg_metrics["pg_clipfrac_lower"], |
| "actor/entropy_loss": pg_metrics["entropy_loss"], |
| "actor/ppo_kl": pg_metrics["ppo_kl"], |
| } |
| append_to_dict(metrics, batch_metrics) |
|
|
| grad_norm = self._optimizer_step() |
| append_to_dict(metrics, {"actor/grad_norm": grad_norm.detach().item()}) |
|
|
| return metrics |
|
|