# AxiomForgeAI — src/rl/value_network.py
# (Hugging Face Space deployment metadata, preserved as comments so the
#  module parses: uploaded by jampuramprem, "Initial Space deployment",
#  commit ec4ae03.)
"""
Value Network (Critic) for PPO.
ValueHead wraps a frozen copy of the base language model backbone and
appends a small MLP to regress a scalar value V(s_t) ∈ ℝ.
Design notes
------------
- The backbone is loaded once with bfloat16 to fit on GPU.
- Only the MLP head (value_head) is updated during training; the
backbone can optionally be unfrozen for fine-grained critic learning.
- The forward pass returns a 1-D tensor of shape (batch_size,) so the
caller can do .item() for single inputs.
"""
from __future__ import annotations
import logging
from typing import Any, Optional
import torch
import torch.nn as nn
from transformers import AutoConfig, AutoModel
from src.utils.attn_backend import select_attn_implementation
logger = logging.getLogger(__name__)
class ValueHead(nn.Module):
    """
    Critic network V_φ(s).

    Architecture
    ------------
    backbone (LM encoder, frozen by default)
        ↓ last-token hidden state [hidden_size]
    Linear(hidden_size, 256) + ReLU
        ↓
    Linear(256, 1)
        ↓ squeeze → scalar V(s)

    Args:
        base_model_path  : HuggingFace model id or local checkpoint path.
        freeze_backbone  : If True, backbone weights are not updated.
                           Defaults to True (only head is trained).
        hidden_size      : Override backbone hidden size (auto-detected
                           from config when None).
        model_device_map : Forwarded to ``from_pretrained`` as
                           ``device_map`` (default "auto" lets accelerate
                           place the weights).
        max_memory       : Optional per-device memory budget dict,
                           forwarded to ``from_pretrained`` so accelerate
                           can respect it during placement.
    """

    def __init__(
        self,
        base_model_path: str,
        freeze_backbone: bool = True,
        hidden_size: Optional[int] = None,
        model_device_map: Optional[Any] = "auto",
        max_memory: Optional[dict] = None,
    ) -> None:
        super().__init__()
        logger.info("Loading ValueHead backbone from %s", base_model_path)
        config = AutoConfig.from_pretrained(
            base_model_path, trust_remote_code=True
        )
        # Auto-detect the hidden width from the config unless the caller
        # overrides it explicitly.
        h = hidden_size or config.hidden_size
        # bfloat16 keeps the (large) backbone within GPU memory; device
        # placement is delegated to accelerate via `device_map`.
        load_kwargs = {
            "torch_dtype": torch.bfloat16,
            "device_map": model_device_map,
            "low_cpu_mem_usage": True,
            "trust_remote_code": True,
            "attn_implementation": select_attn_implementation(),
        }
        # BUGFIX: `max_memory` was accepted but silently ignored — forward
        # it so callers can actually cap per-device memory usage.
        if max_memory is not None:
            load_kwargs["max_memory"] = max_memory
        self.backbone = AutoModel.from_pretrained(
            base_model_path,
            **load_kwargs,
        )
        if freeze_backbone:
            for param in self.backbone.parameters():
                param.requires_grad_(False)
            logger.info("Backbone frozen; only ValueHead MLP will be trained.")
        # Small trainable regression head: hidden → 256 → 1 scalar.
        # Created in default fp32; forward() casts gathered hidden states
        # to the head's weight dtype, so a bf16 backbone + fp32 head works.
        self.value_head = nn.Sequential(
            nn.Linear(h, 256),
            nn.ReLU(),
            nn.Linear(256, 1),
        )

    # ------------------------------------------------------------------
    # Forward
    # ------------------------------------------------------------------
    def forward(
        self,
        input_ids: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """
        Compute V(s) for a batch of states.

        Args:
            input_ids      : [batch, seq_len]
            attention_mask : [batch, seq_len] (ones if None)

        Returns:
            values : [batch] — scalar value estimate per sequence
        """
        if attention_mask is None:
            attention_mask = torch.ones_like(input_ids)
        outputs = self.backbone(
            input_ids=input_ids,
            attention_mask=attention_mask,
        )
        # Last *non-pad* token (right-padded batches: last valid index per row)
        last_hidden = outputs.last_hidden_state  # [B, T, H]
        last_idx = attention_mask.long().sum(dim=1) - 1
        # Guard against an all-zero mask row producing index -1.
        last_idx = last_idx.clamp(min=0)
        b = torch.arange(last_hidden.size(0), device=last_hidden.device)
        # Cast to the head's weight dtype so a bf16 backbone feeds an
        # fp32 head without a dtype mismatch.
        cls_hidden = last_hidden[b, last_idx].to(self.value_head[0].weight.dtype)
        values = self.value_head(cls_hidden).squeeze(-1)  # [B]
        return values

    @torch.no_grad()
    def values_at_positions(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """
        Compute V(s_t) for many states in a SINGLE backbone forward pass.

        The naive rollout loop calls ``self.value(...)`` once per generated
        token, which does one full backbone forward over the growing
        sequence each step — that's O(T^2) work for T tokens. This helper
        lets the caller run the backbone exactly once on the full
        trajectory and then pluck hidden states at the positions that
        correspond to each state s_t.

        For a trajectory with prompt length P and T generated tokens,
        state s_t (= prompt + generated[:t], t=0..T-1) is a "last token"
        at position P + t - 1 in the full sequence, so callers pass
        ``positions = torch.arange(P - 1, P + T - 1)``.

        Args:
            input_ids:
                [1, L] full trajectory (prompt + generated). A single
                un-padded sequence — callers that need batched different-
                length trajectories should loop over them (cheap because
                each call is O(L), not O(L^2)).
            positions:
                [N] long tensor of indices into the L-axis. Hidden states
                at these positions will be fed through the value MLP.
            attention_mask:
                Optional [1, L] mask. Defaults to all-ones.

        Returns:
            values: [N] scalar value estimates, one per requested position,
                on the same device as ``input_ids`` and already in float32
                (so callers can safely ``.tolist()`` them for the buffer).
        """
        if attention_mask is None:
            attention_mask = torch.ones_like(input_ids)
        outputs = self.backbone(
            input_ids=input_ids,
            attention_mask=attention_mask,
        )
        hidden = outputs.last_hidden_state  # [1, L, H]
        positions = positions.to(device=hidden.device, dtype=torch.long)
        # Clamp just in case the caller requests an out-of-range position
        # (e.g. T=0 edge cases). clamp is a no-op for valid indices.
        positions = positions.clamp(min=0, max=hidden.size(1) - 1)
        # Gather → [N, H]. Cast to the value_head's weight dtype so
        # bf16 backbone + fp32 head works regardless of how torch
        # autocast is configured on the caller side.
        gathered = hidden[0, positions].to(self.value_head[0].weight.dtype)
        values = self.value_head(gathered).squeeze(-1).float()  # [N]
        return values