"""
Value Network (Critic) for PPO.

ValueHead wraps a frozen copy of the base language model backbone and
appends a small MLP to regress a scalar value V(s_t) ∈ ℝ.

Design notes
------------
- The backbone is loaded once in bfloat16 to fit on GPU.
- Only the MLP head (value_head) is updated during training; the
  backbone can optionally be unfrozen for fine-grained critic learning.
- The forward pass returns a 1-D tensor of shape (batch_size,) so the
  caller can do .item() for single inputs.
"""

from __future__ import annotations

import logging
from typing import Any, Optional

import torch
import torch.nn as nn
from transformers import AutoConfig, AutoModel

from src.utils.attn_backend import select_attn_implementation

logger = logging.getLogger(__name__)


class ValueHead(nn.Module):
    """
    Critic network V_φ(s).

    Architecture
    ------------
    backbone (LM encoder, frozen by default)
        ↓  last-token hidden state  [hidden_size]
    Linear(hidden_size, 256) + ReLU
        ↓
    Linear(256, 1)
        ↓  squeeze  →  scalar V(s)

    Args:
        base_model_path : HuggingFace model id or local checkpoint path.
        freeze_backbone : If True, backbone weights are not updated.
                          Defaults to True (only the head is trained).
        hidden_size     : Override backbone hidden size (auto-detected
                          from config when None).
        model_device_map: Passed to ``from_pretrained`` as ``device_map``
                          (defaults to "auto").
        max_memory      : Optional per-device memory cap, forwarded to
                          ``from_pretrained`` when provided.
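
    Example
    -------
    Minimal usage sketch ("my-org/base-lm" is a placeholder model id and
    ``tokenizer`` is assumed to match the backbone)::

        critic = ValueHead("my-org/base-lm", freeze_backbone=True)
        batch = tokenizer(prompts, return_tensors="pt", padding=True)
        values = critic(batch["input_ids"], batch["attention_mask"])  # [B]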
    """

    def __init__(
        self,
        base_model_path: str,
        freeze_backbone: bool = True,
        hidden_size: Optional[int] = None,
        model_device_map: Optional[Any] = "auto",
        max_memory: Optional[dict] = None,
    ) -> None:
        super().__init__()

        logger.info(f"Loading ValueHead backbone from {base_model_path}")

        config = AutoConfig.from_pretrained(
            base_model_path, trust_remote_code=True
        )
        h = hidden_size or config.hidden_size

        # Load with the requested device_map (defaults to "auto", which
        # shards the backbone across available devices); max_memory, when
        # provided, caps per-device usage during placement.
        load_kwargs = {
            "torch_dtype": torch.bfloat16,
            "device_map": model_device_map,
            "low_cpu_mem_usage": True,
            "trust_remote_code": True,
            "attn_implementation": select_attn_implementation(),
        }
        if max_memory is not None:
            load_kwargs["max_memory"] = max_memory

        self.backbone = AutoModel.from_pretrained(
            base_model_path,
            **load_kwargs,
        )

        if freeze_backbone:
            for param in self.backbone.parameters():
                param.requires_grad_(False)
            logger.info("Backbone frozen; only ValueHead MLP will be trained.")

        self.value_head = nn.Sequential(
            nn.Linear(h, 256),
            nn.ReLU(),
            nn.Linear(256, 1),
        )

    # ------------------------------------------------------------------
    # Forward
    # ------------------------------------------------------------------

    def forward(
        self,
        input_ids: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """
        Compute V(s) for a batch of states.

        Args:
            input_ids      : [batch, seq_len]
            attention_mask : [batch, seq_len] (ones if None)

        Returns:
            values : [batch], one scalar value estimate per sequence
        """
        if attention_mask is None:
            attention_mask = torch.ones_like(input_ids)

        outputs = self.backbone(
            input_ids=input_ids,
            attention_mask=attention_mask,
        )

        # Last *non-pad* token (right-padded batches: last valid index per row)
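        # e.g. attention_mask = [[1, 1, 1, 0]] -> last_idx = [2]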
        last_hidden = outputs.last_hidden_state  # [B, T, H]
        last_idx = attention_mask.long().sum(dim=1) - 1
        last_idx = last_idx.clamp(min=0)
        b = torch.arange(last_hidden.size(0), device=last_hidden.device)
        last_tok_hidden = last_hidden[b, last_idx].to(self.value_head[0].weight.dtype)

        values = self.value_head(last_tok_hidden).squeeze(-1)  # [B]
        return values

    @torch.no_grad()
    def values_at_positions(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """
        Compute V(s_t) for many states in a SINGLE backbone forward pass.

        The naive rollout loop calls ``self.value(...)`` once per generated
        token, which does one full backbone forward over the growing
        sequence each step: that's O(T^2) work for T tokens.  This helper
        lets the caller run the backbone exactly once on the full
        trajectory and then pluck hidden states at the positions that
        correspond to each state s_t.

        For a trajectory with prompt length P and T generated tokens,
        state s_t (= prompt + generated[:t], t=0..T-1) is a "last token"
        at position P + t - 1 in the full sequence, so callers pass
        ``positions = torch.arange(P - 1, P + T - 1)``.

        Args:
            input_ids:
                [1, L] full trajectory (prompt + generated).  A single
                un-padded sequence; callers that need batched different-
                length trajectories should loop over them (cheap because
                each call is O(L), not O(L^2)).
            positions:
                [N] long tensor of indices into the L-axis.  Hidden states
                at these positions will be fed through the value MLP.
            attention_mask:
                Optional [1, L] mask.  Defaults to all-ones.

        Returns:
            values: [N] scalar value estimates, one per requested position,
                on the same device as ``input_ids`` and already in float32
                (so callers can safely ``.tolist()`` them for the buffer).
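
        Example
        -------
        For a prompt of length P=5 and T=3 generated tokens, states
        s_0..s_2 end at positions 4, 5 and 6 of the full sequence::

            positions = torch.arange(5 - 1, 5 + 3 - 1)  # tensor([4, 5, 6])
            values = critic.values_at_positions(full_ids, positions)  # [3]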
        """
        if attention_mask is None:
            attention_mask = torch.ones_like(input_ids)

        outputs = self.backbone(
            input_ids=input_ids,
            attention_mask=attention_mask,
        )
        hidden = outputs.last_hidden_state  # [1, L, H]

        positions = positions.to(device=hidden.device, dtype=torch.long)
        # Clamp just in case the caller requests an out-of-range position
        # (e.g. T=0 edge cases).  clamp is a no-op for valid indices.
        positions = positions.clamp(min=0, max=hidden.size(1) - 1)

        # Gather → [N, H].  Cast to the value_head's weight dtype so
        # bf16 backbone + fp32 head works regardless of how torch
        # autocast is configured on the caller side.
        gathered = hidden[0, positions].to(self.value_head[0].weight.dtype)
        values = self.value_head(gathered).squeeze(-1).float()  # [N]
        return values
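

if __name__ == "__main__":
    # Minimal smoke test: random token ids through a small public model.
    # "gpt2" is illustrative only, not the backbone this project trains;
    # any AutoModel-compatible checkpoint works.  device_map=None keeps
    # the test on CPU.
    critic = ValueHead("gpt2", model_device_map=None)
    fake_ids = torch.randint(0, critic.backbone.config.vocab_size, (2, 8))
    print(critic(fake_ids).shape)  # expected: torch.Size([2])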