"""
Value Network (Critic) for PPO.
ValueHead wraps a frozen copy of the base language model backbone and
appends a small MLP to regress a scalar value V(s_t) ∈ ℝ.
Design notes
------------
- The backbone is loaded once with bfloat16 to fit on GPU.
- Only the MLP head (value_head) is updated during training; the
backbone can optionally be unfrozen for fine-grained critic learning.
- The forward pass returns a 1-D tensor of shape (batch_size,) so the
caller can do .item() for single inputs.
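Example (sketch; the checkpoint path and tensor names below are illustrative):
    critic = ValueHead("path/to/base-model")
    values = critic(input_ids, attention_mask)         # [batch]
    v_t = critic.values_at_positions(traj_ids, pos)    # [N]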
"""
from __future__ import annotations
import logging
from typing import Any, Optional
import torch
import torch.nn as nn
from transformers import AutoConfig, AutoModel
from src.utils.attn_backend import select_attn_implementation
logger = logging.getLogger(__name__)
class ValueHead(nn.Module):
"""
Critic network V_φ(s).
Architecture
------------
backbone (LM encoder, frozen by default)
    ↓ last-token hidden state  [hidden_size]
Linear(hidden_size, 256) + ReLU
    ↓
Linear(256, 1)
    ↓ squeeze → scalar V(s)
Args:
base_model_path : HuggingFace model id or local checkpoint path.
freeze_backbone : If True, backbone weights are not updated.
Defaults to True (only head is trained).
hidden_size : Override backbone hidden size (auto-detected
from config when None).
model_device_map : Forwarded to ``from_pretrained`` as ``device_map``
(defaults to "auto"; pass None to keep the backbone on CPU).
max_memory : Optional per-device memory budget forwarded to
``from_pretrained`` when provided.
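Example (sketch; the checkpoint id and learning rate are illustrative):
    critic = ValueHead("org/base-model", freeze_backbone=True)
    optim = torch.optim.Adam(
        [p for p in critic.parameters() if p.requires_grad], lr=1e-4
    )
    values = critic(input_ids, attention_mask)   # [batch]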
"""
def __init__(
self,
base_model_path: str,
freeze_backbone: bool = True,
hidden_size: Optional[int] = None,
model_device_map: Optional[Any] = "auto",
max_memory: Optional[dict] = None,
) -> None:
super().__init__()
logger.info(f"Loading ValueHead backbone from {base_model_path}")
config = AutoConfig.from_pretrained(
base_model_path, trust_remote_code=True
)
h = hidden_size or config.hidden_size
# device_map defaults to "auto" so accelerate places the backbone;
# pass model_device_map=None to load on CPU and move it manually.
load_kwargs = {
"torch_dtype": torch.bfloat16,
"device_map": model_device_map,
"low_cpu_mem_usage": True,
"trust_remote_code": True,
"attn_implementation": select_attn_implementation(),
}
if max_memory is not None:
    load_kwargs["max_memory"] = max_memory
self.backbone = AutoModel.from_pretrained(
base_model_path,
**load_kwargs,
)
if freeze_backbone:
for param in self.backbone.parameters():
param.requires_grad_(False)
logger.info("Backbone frozen; only ValueHead MLP will be trained.")
self.value_head = nn.Sequential(
nn.Linear(h, 256),
nn.ReLU(),
nn.Linear(256, 1),
)
# ------------------------------------------------------------------
# Forward
# ------------------------------------------------------------------
def forward(
self,
input_ids: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
) -> torch.Tensor:
"""
Compute V(s) for a batch of states.
Args:
input_ids : [batch, seq_len]
attention_mask : [batch, seq_len] (ones if None)
Returns:
values : [batch], one scalar value estimate per sequence
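Example (illustrative shapes only; the ids are random placeholders):
    ids = torch.randint(0, 100, (4, 32))   # [batch=4, seq_len=32]
    critic(ids).shape                       # torch.Size([4])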
"""
if attention_mask is None:
attention_mask = torch.ones_like(input_ids)
outputs = self.backbone(
input_ids=input_ids,
attention_mask=attention_mask,
)
# Last *non-pad* token (right-padded batches: last valid index per row)
last_hidden = outputs.last_hidden_state # [B, T, H]
last_idx = attention_mask.long().sum(dim=1) - 1
last_idx = last_idx.clamp(min=0)
b = torch.arange(last_hidden.size(0), device=last_hidden.device)
cls_hidden = last_hidden[b, last_idx].to(self.value_head[0].weight.dtype)
values = self.value_head(cls_hidden).squeeze(-1) # [B]
return values
@torch.no_grad()
def values_at_positions(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
) -> torch.Tensor:
"""
Compute V(s_t) for many states in a SINGLE backbone forward pass.
The naive rollout loop calls ``self.value(...)`` once per generated
token, which does one full backbone forward over the growing
sequence each step, which is O(T^2) work for T tokens. This helper
lets the caller run the backbone exactly once on the full
trajectory and then pluck hidden states at the positions that
correspond to each state s_t.
For a trajectory with prompt length P and T generated tokens,
state s_t (= prompt + generated[:t], t=0..T-1) is a "last token"
at position P + t - 1 in the full sequence, so callers pass
``positions = torch.arange(P - 1, P + T - 1)``.
Args:
input_ids:
[1, L] full trajectory (prompt + generated). A single
un-padded sequence; callers that need batched different-
length trajectories should loop over them (cheap because
each call is O(L), not O(L^2)).
positions:
[N] long tensor of indices into the L-axis. Hidden states
at these positions will be fed through the value MLP.
attention_mask:
Optional [1, L] mask. Defaults to all-ones.
Returns:
values: [N] scalar value estimates, one per requested position,
on the same device as ``input_ids`` and already in float32
(so callers can safely ``.tolist()`` them for the buffer).
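Example (illustrative; ``prompt_ids`` / ``gen_ids`` are hypothetical
[1, P] / [1, T] tensors from a rollout):
    P, T = prompt_ids.size(1), gen_ids.size(1)
    traj = torch.cat([prompt_ids, gen_ids], dim=1)   # [1, P + T]
    pos = torch.arange(P - 1, P + T - 1)             # one index per state s_t
    v_t = critic.values_at_positions(traj, pos)      # [T]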
"""
if attention_mask is None:
attention_mask = torch.ones_like(input_ids)
outputs = self.backbone(
input_ids=input_ids,
attention_mask=attention_mask,
)
hidden = outputs.last_hidden_state # [1, L, H]
positions = positions.to(device=hidden.device, dtype=torch.long)
# Clamp just in case the caller requests an out-of-range position
# (e.g. T=0 edge cases). clamp is a no-op for valid indices.
positions = positions.clamp(min=0, max=hidden.size(1) - 1)
# Gather -> [N, H]. Cast to the value_head's weight dtype so
# bf16 backbone + fp32 head works regardless of how torch
# autocast is configured on the caller side.
gathered = hidden[0, positions].to(self.value_head[0].weight.dtype)
values = self.value_head(gathered).squeeze(-1).float() # [N]
return values
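

# ----------------------------------------------------------------------
# Usage sketch (illustrative only, not part of the training code).
# The checkpoint id below is a placeholder; substitute any causal LM
# available locally. It compares naive per-step critic calls against a
# single values_at_positions() pass; for a causal backbone the two
# should agree up to bf16 numerical noise.
# ----------------------------------------------------------------------
if __name__ == "__main__":
    critic = ValueHead("sshleifer/tiny-gpt2", model_device_map=None)
    P, T = 8, 8                                   # prompt / generated lengths
    traj = torch.randint(0, 100, (1, P + T))      # [1, L] fake trajectory

    with torch.no_grad():
        # Naive O(T^2): one full backbone forward per state s_t = traj[:, :P+t].
        naive = torch.cat([critic(traj[:, : P + t]) for t in range(T)])
        # Single O(L) pass over the full trajectory.
        fast = critic.values_at_positions(traj, torch.arange(P - 1, P + T - 1))

    print("per-step :", [round(v, 4) for v in naive.tolist()])
    print("one-pass :", [round(v, 4) for v in fast.tolist()])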