| """ | |
| Value Network (Critic) for PPO. | |
| ValueHead wraps a frozen copy of the base language model backbone and | |
| appends a small MLP to regress a scalar value V(s_t) β β. | |
| Design notes | |
| ------------ | |
| - The backbone is loaded once with bfloat16 to fit on GPU. | |
| - Only the MLP head (value_head) is updated during training; the | |
| backbone can optionally be unfrozen for fine-grained critic learning. | |
| - The forward pass returns a 1-D tensor of shape (batch_size,) so the | |
| caller can do .item() for single inputs. | |
| """ | |
from __future__ import annotations

import logging
from typing import Any, Optional

import torch
import torch.nn as nn
from transformers import AutoConfig, AutoModel

from src.utils.attn_backend import select_attn_implementation

logger = logging.getLogger(__name__)
class ValueHead(nn.Module):
    """
    Critic network V_φ(s).

    Architecture
    ------------
    backbone (LM encoder, frozen by default)
        ↓ last-token hidden state  [hidden_size]
    Linear(hidden_size, 256) + ReLU
        ↓
    Linear(256, 1)
        ↓ squeeze → scalar V(s)

    Args:
        base_model_path  : HuggingFace model id or local checkpoint path.
        freeze_backbone  : If True, backbone weights are not updated.
                           Defaults to True (only the head is trained).
        hidden_size      : Override backbone hidden size (auto-detected
                           from config when None).
        model_device_map : Passed to ``from_pretrained`` as ``device_map``
                           (default "auto").
        max_memory       : Optional per-device memory budget forwarded to
                           ``from_pretrained`` alongside the device map.
    """
    def __init__(
        self,
        base_model_path: str,
        freeze_backbone: bool = True,
        hidden_size: Optional[int] = None,
        model_device_map: Optional[Any] = "auto",
        max_memory: Optional[dict] = None,
    ) -> None:
        super().__init__()
        logger.info(f"Loading ValueHead backbone from {base_model_path}")
        config = AutoConfig.from_pretrained(
            base_model_path, trust_remote_code=True
        )
        h = hidden_size or config.hidden_size

        # Load in bfloat16 with the requested device map (default "auto").
        # Pass model_device_map=None to load on CPU and let the caller
        # move the backbone to GPU explicitly.
        load_kwargs = {
            "torch_dtype": torch.bfloat16,
            "device_map": model_device_map,
            "low_cpu_mem_usage": True,
            "trust_remote_code": True,
            "attn_implementation": select_attn_implementation(),
        }
        if max_memory is not None:
            # Forward an explicit per-device memory budget to from_pretrained.
            load_kwargs["max_memory"] = max_memory
        self.backbone = AutoModel.from_pretrained(
            base_model_path,
            **load_kwargs,
        )

        if freeze_backbone:
            for param in self.backbone.parameters():
                param.requires_grad_(False)
            logger.info("Backbone frozen; only ValueHead MLP will be trained.")

        self.value_head = nn.Sequential(
            nn.Linear(h, 256),
            nn.ReLU(),
            nn.Linear(256, 1),
        )
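
    # Sketch (illustrative): the docstring notes the backbone can be
    # unfrozen for fine-grained critic learning. A caller would flip the
    # requires_grad flags back on; ``critic`` is a placeholder name:
    #
    #     critic = ValueHead("path/to/checkpoint")
    #     for p in critic.backbone.parameters():
    #         p.requires_grad_(True)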
    # ------------------------------------------------------------------
    # Forward
    # ------------------------------------------------------------------
    def forward(
        self,
        input_ids: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """
        Compute V(s) for a batch of states.

        Args:
            input_ids      : [batch, seq_len]
            attention_mask : [batch, seq_len] (all-ones if None)

        Returns:
            values : [batch], one scalar value estimate per sequence.
        """
        if attention_mask is None:
            attention_mask = torch.ones_like(input_ids)
        outputs = self.backbone(
            input_ids=input_ids,
            attention_mask=attention_mask,
        )
        # Pool the last *non-pad* token (assumes right-padded batches:
        # the last valid index per row).
        last_hidden = outputs.last_hidden_state              # [B, T, H]
        last_idx = attention_mask.long().sum(dim=1) - 1
        last_idx = last_idx.clamp(min=0)
        b = torch.arange(last_hidden.size(0), device=last_hidden.device)
        # Cast to the head's weight dtype (bf16 backbone + fp32 head).
        pooled = last_hidden[b, last_idx].to(self.value_head[0].weight.dtype)
        values = self.value_head(pooled).squeeze(-1)         # [B]
        return values
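
    # Sketch (illustrative): single-state evaluation via ``forward``;
    # ``critic`` and the input tensors are placeholder names:
    #
    #     v = critic(input_ids, attention_mask)   # shape [B]
    #     scalar = v[0].item()                     # per module docstring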
    def values_at_positions(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """
        Compute V(s_t) for many states in a SINGLE backbone forward pass.

        The naive rollout loop calls ``self.value(...)`` once per
        generated token, doing one full backbone forward over the growing
        sequence each step; that's O(T^2) work for T tokens. This helper
        lets the caller run the backbone exactly once on the full
        trajectory and then pluck hidden states at the positions that
        correspond to each state s_t.

        For a trajectory with prompt length P and T generated tokens,
        state s_t (= prompt + generated[:t], t = 0..T-1) has its "last
        token" at position P + t - 1 in the full sequence, so callers
        pass ``positions = torch.arange(P - 1, P + T - 1)``.

        Args:
            input_ids:
                [1, L] full trajectory (prompt + generated). A single
                un-padded sequence; callers that need batches of
                different-length trajectories should loop over them
                (cheap because each call is O(L), not O(L^2)).
            positions:
                [N] long tensor of indices into the L-axis. Hidden states
                at these positions are fed through the value MLP.
            attention_mask:
                Optional [1, L] mask. Defaults to all-ones.

        Returns:
            values: [N] scalar value estimates, one per requested
                position, on the same device as ``input_ids`` and already
                in float32 (so callers can safely ``.tolist()`` them for
                the buffer).
        """
        if attention_mask is None:
            attention_mask = torch.ones_like(input_ids)
        outputs = self.backbone(
            input_ids=input_ids,
            attention_mask=attention_mask,
        )
        hidden = outputs.last_hidden_state                   # [1, L, H]
        positions = positions.to(device=hidden.device, dtype=torch.long)
        # Clamp just in case the caller requests an out-of-range position
        # (e.g. T=0 edge cases); clamp is a no-op for valid indices.
        positions = positions.clamp(min=0, max=hidden.size(1) - 1)
        # Gather -> [N, H]. Cast to the value_head's weight dtype so a
        # bf16 backbone + fp32 head works regardless of how torch
        # autocast is configured on the caller side.
        gathered = hidden[0, positions].to(self.value_head[0].weight.dtype)
        values = self.value_head(gathered).squeeze(-1).float()  # [N]
        return values
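

# ----------------------------------------------------------------------
# Usage sketch (illustrative, not part of the training pipeline). It
# shows the single-pass evaluation described in ``values_at_positions``;
# the model id below is a placeholder and must point at a real checkpoint.
# ----------------------------------------------------------------------
if __name__ == "__main__":
    from transformers import AutoTokenizer

    model_id = "path/or/hf-id"  # placeholder, not a real checkpoint
    tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
    critic = ValueHead(model_id)

    prompt = tokenizer("2 + 2 =", return_tensors="pt").input_ids
    gen = tokenizer(" 4", add_special_tokens=False, return_tensors="pt").input_ids
    full = torch.cat([prompt, gen], dim=1)  # [1, P + T]
    full = full.to(next(critic.backbone.parameters()).device)
    # The MLP head is created on CPU; co-locate it with the inputs.
    critic.value_head.to(full.device)

    P, T = prompt.size(1), gen.size(1)
    with torch.no_grad():
        # One backbone pass yields V(s_t) for every generated step t.
        values = critic.values_at_positions(
            full, torch.arange(P - 1, P + T - 1)
        )
    print(values.tolist())  # T float32 value estimates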