File size: 6,461 Bytes
28767b8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
"""
Darwin-9B-NEG — Native Entropy Gating enabled model.

Helper module to attach NEG (Native Entropy Gating) to a Darwin base model.
Provides:
  - NEGHead  : predicts per-token entropy from last hidden state
  - NEGGate  : entropy-gated top-k logit masking (the top-1 token always survives the mask, so the effect shows when sampling)
  - attach_neg(model, path_or_repo) : monkey-patches forward to apply NEG

See README.md for usage.
"""
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
from safetensors.torch import load_file


class NEGHead(nn.Module):
    """NEG-Head: small bottleneck MLP that scores a hidden state.

    The hidden vector is projected down to ``hidden // 4`` features, passed
    through GELU and dropout, projected to a single value and finally mapped
    through softplus so the result is never negative.

    Input: hidden_state [..., H] (the linear layers act on the last dim)
    Output: predicted_entropy [...] (trailing singleton dim squeezed away)
    """
    def __init__(self, hidden: int, dropout: float = 0.1):
        super().__init__()
        bottleneck = hidden // 4
        # Attribute names double as state-dict keys, and the creation order
        # fixes the order in which random initial weights are drawn.
        self.proj_down = nn.Linear(hidden, bottleneck)
        self.act = nn.GELU()
        self.dropout = nn.Dropout(dropout)
        self.proj_out = nn.Linear(bottleneck, 1)

    def forward(self, h):
        features = self.dropout(self.act(self.proj_down(h)))
        raw_score = self.proj_out(features)
        # softplus keeps the prediction >= 0; squeeze drops the size-1 dim.
        return F.softplus(raw_score.squeeze(-1))


class NEGGate(nn.Module):
    """NEG-Gate: entropy-gated top-k logit masking.

    For every row whose predicted entropy is strictly greater than
    ``threshold``, all logits outside that row's ``top_k`` largest values are
    replaced by ``-inf``. Rows at or below the threshold are returned
    unchanged.

    Note: the largest logit of a row is always part of its top-k set, so the
    mask never changes which token scores highest; it only removes the tail
    of the distribution, which matters when tokens are sampled.

    Args:
        init_threshold: initial value of the ``threshold`` parameter.
        top_k: number of candidates kept in a gated row. Values larger than
            the vocabulary size keep every logit.
    """
    def __init__(self, init_threshold: float = 1.175, top_k: int = 20):
        super().__init__()
        # Stored as a Parameter so it is saved/loaded with the state dict.
        self.threshold = nn.Parameter(torch.tensor(init_threshold))
        self.top_k = top_k

    def forward(self, logits, predicted_entropy):
        """Mask ``logits`` row-wise where ``predicted_entropy`` is high.

        Args:
            logits: tensor whose last dim indexes the vocabulary.
            predicted_entropy: one value per row of ``logits``.

        Returns:
            A tensor shaped like ``logits``; the input object itself when no
            row is gated.
        """
        activate = (predicted_entropy > self.threshold).unsqueeze(-1)
        if not activate.any():
            return logits
        # topk raises if k exceeds the size of the dimension, so clamp it.
        k = min(self.top_k, logits.size(-1))
        top_k_vals, top_k_idx = logits.topk(k, dim=-1)
        masked = torch.full_like(logits, float('-inf'))
        masked.scatter_(-1, top_k_idx, top_k_vals)
        # Select per row instead of blending arithmetically: multiplying the
        # -inf entries by a zero gate would produce NaN in the ungated rows.
        return torch.where(activate, masked, logits)


def attach_neg(base_model, neg_path_or_repo, hf_token=None):
    """Attach NEG to a loaded base model.

    Loads ``neg_modules.safetensors``, builds a NEGHead and a NEGGate from
    it, stores them on the model as ``neg_head`` / ``neg_gate`` and wraps
    ``base_model.forward`` so that the logits of the last sequence position
    are passed through the gate on every call. Calling this function again
    on an already-patched model swaps in the freshly loaded modules without
    wrapping ``forward`` a second time.

    Args:
        base_model: a HuggingFace AutoModelForCausalLM instance
        neg_path_or_repo: local directory or HF repo containing
            neg_modules.safetensors
        hf_token: optional HF token (for private repos); falls back to the
            ``HF_TOKEN`` environment variable.

    Returns:
        The same model with NEG-Head and NEG-Gate attached and forward() wrapped
        to apply NEG at each generation step.

    Raises:
        FileNotFoundError: the file is not in the local directory and could
            not be downloaded (the underlying error is chained as cause).
        ValueError: ``hidden_size`` is found neither on ``config`` nor on
            ``config.text_config``.
    """
    # Find neg_modules.safetensors: a local directory wins, otherwise the
    # argument is treated as a Hub repo id.
    neg_file = None
    if os.path.isdir(neg_path_or_repo):
        candidate = os.path.join(neg_path_or_repo, "neg_modules.safetensors")
        if os.path.exists(candidate):
            neg_file = candidate
    if neg_file is None:
        try:
            from huggingface_hub import hf_hub_download
            neg_file = hf_hub_download(
                repo_id=neg_path_or_repo,
                filename="neg_modules.safetensors",
                token=hf_token or os.environ.get("HF_TOKEN"),
            )
        except Exception as e:
            # Keep the original exception as __cause__ for debugging.
            raise FileNotFoundError(
                f"Cannot locate neg_modules.safetensors at {neg_path_or_repo}: {e}"
            ) from e

    # Determine hidden size and device
    hidden_size = getattr(base_model.config, "hidden_size", None)
    if hidden_size is None:
        hidden_size = getattr(getattr(base_model.config, "text_config", None), "hidden_size", None)
    if hidden_size is None:
        raise ValueError("Could not determine hidden_size from model config.")

    device = next(base_model.parameters()).device

    # Load state dict and split it by the "head." / "gate." key prefixes
    state = load_file(neg_file)
    head_sd = {k.replace("head.", "", 1): v for k, v in state.items() if k.startswith("head.")}
    gate_sd = {k.replace("gate.", "", 1): v for k, v in state.items() if k.startswith("gate.")}

    # Build and load NEG modules (kept in float32 regardless of model dtype)
    head = NEGHead(hidden_size).to(device=device, dtype=torch.float32)
    if head_sd:
        head.load_state_dict(head_sd)
    else:
        # Without weights the head would gate on random predictions; make
        # that visible instead of failing silently.
        import warnings
        warnings.warn(
            f"No 'head.*' tensors found in {neg_file}; "
            "NEG-Head keeps its random initialisation.",
            stacklevel=2,
        )
    head.eval()

    # Infer gate params from state
    gate_threshold = gate_sd.get("threshold", torch.tensor(1.175)).item()
    # top_k is not a learnable param; read from metadata if present, else default 20
    top_k = int(state["meta.top_k"].item()) if "meta.top_k" in state else 20
    gate = NEGGate(init_threshold=gate_threshold, top_k=top_k).to(
        device=device, dtype=torch.float32
    )
    if gate_sd:
        gate.load_state_dict(gate_sd)
    gate.eval()

    # Attach (the wrapper below looks these up on the model at call time)
    base_model.neg_head = head
    base_model.neg_gate = gate

    # Wrap forward only once; a repeated attach just replaces the modules.
    if not getattr(base_model, "_neg_attached", False):
        original_forward = base_model.forward

        def forward_with_neg(*args, **kwargs):
            # Force hidden states capture
            kwargs["output_hidden_states"] = True
            out = original_forward(*args, **kwargs)
            # Tuple outputs (return_dict=False) carry no .hidden_states
            # attribute; pass them through untouched like the None case.
            hidden_states = getattr(out, "hidden_states", None)
            if hidden_states is None:
                return out
            # The model may be spread over several devices, so the final
            # hidden state / logits need not live where the NEG modules do.
            neg_device = next(base_model.neg_head.parameters()).device
            last_hidden = hidden_states[-1][:, -1].to(device=neg_device, dtype=torch.float32)
            pred_ent = base_model.neg_head(last_hidden)
            logits = out.logits
            last_logits = logits[:, -1].to(device=neg_device, dtype=torch.float32)
            guided = base_model.neg_gate(last_logits, pred_ent)
            # Clone and replace last position
            new_logits = logits.clone()
            new_logits[:, -1] = guided.to(device=logits.device, dtype=logits.dtype)
            out.logits = new_logits
            return out

        base_model.forward = forward_with_neg
        base_model._neg_attached = True

    print("[Darwin-NEG] NEG attached successfully.")
    print(f"[Darwin-NEG]   threshold = {gate.threshold.item():.4f}")
    print(f"[Darwin-NEG]   top_k     = {gate.top_k}")
    print(f"[Darwin-NEG]   head params: {sum(p.numel() for p in head.parameters()):,}")
    return base_model


def load_darwin_neg(repo_or_path, torch_dtype=torch.bfloat16, device_map="auto",
                    trust_remote_code=True, hf_token=None, **kwargs):
    """Convenience loader: loads base model + attaches NEG in one call.

    Args:
        repo_or_path: HF repo id or local path; used both for the base model
            and for locating neg_modules.safetensors.
        torch_dtype: dtype forwarded to ``from_pretrained``.
        device_map: device map forwarded to ``from_pretrained``.
        trust_remote_code: forwarded to ``from_pretrained``. NOTE(review):
            defaults to True, which allows code shipped with the repo to be
            executed; pass False for repos you do not trust.
        hf_token: optional HF token. When omitted, a ``token`` given in
            ``**kwargs`` and then the ``HF_TOKEN`` environment variable are
            tried.
        **kwargs: extra keyword arguments for ``from_pretrained``.
            ``low_cpu_mem_usage`` defaults to True but may be overridden.

    Returns:
        The loaded model with NEG attached (see ``attach_neg``).

    Example:
        from modeling_darwin_neg import load_darwin_neg
        model = load_darwin_neg("FINAL-Bench/Darwin-9B-NEG", hf_token="hf_...")
    """
    from transformers import AutoModelForCausalLM
    # Pull `token` out of kwargs so it cannot collide with the explicit
    # keyword below ("got multiple values for keyword argument").
    kwargs_token = kwargs.pop("token", None)
    token = hf_token or kwargs_token or os.environ.get("HF_TOKEN")
    # Same reasoning: keep the default but let the caller override it.
    kwargs.setdefault("low_cpu_mem_usage", True)
    base = AutoModelForCausalLM.from_pretrained(
        repo_or_path,
        torch_dtype=torch_dtype,
        device_map=device_map,
        trust_remote_code=trust_remote_code,
        token=token,
        **kwargs,
    )
    return attach_neg(base, repo_or_path, hf_token=token)