# Darwin-9B-NEG / modeling_darwin_neg.py
# Committed by SeaWolf-AI
# Darwin-9B-NEG v1.0: First Native Entropy Gating model (+11.3%p GPQA, greedy decoding)
# Commit aef00eb (verified)
"""
Darwin-9B-NEG — Native Entropy Gating enabled model.
Helper module to attach NEG (Native Entropy Gating) to a Darwin base model.
Provides:
- NEGHead : predicts per-token entropy from last hidden state
- NEGGate : non-monotonic top-k logit masking (effective in greedy decoding)
- attach_neg(model, path_or_repo) : monkey-patches forward to apply NEG
See README.md for usage.
"""
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
from safetensors.torch import load_file
class NEGHead(nn.Module):
    """NEG-Head: regress the entropy of the next-token distribution.

    A two-layer bottleneck MLP:
        Linear(hidden -> hidden // 4) -> GELU -> Dropout -> Linear(-> 1)
    followed by softplus, so the prediction is always >= 0.

    Input:  hidden_state [B, H]
    Output: predicted_entropy [B]
    """

    def __init__(self, hidden: int, dropout: float = 0.1):
        super().__init__()
        bottleneck = hidden // 4
        # attach_neg() loads the "head."-prefixed tensors from
        # neg_modules.safetensors into these attributes, so the attribute
        # names double as state-dict keys and must not be renamed.
        self.proj_down = nn.Linear(hidden, bottleneck)
        self.act = nn.GELU()
        self.dropout = nn.Dropout(dropout)
        self.proj_out = nn.Linear(bottleneck, 1)

    def forward(self, h):
        """Return the non-negative entropy estimate for each row of ``h``."""
        features = self.dropout(self.act(self.proj_down(h)))
        # Drop the trailing size-1 output dim: [..., 1] -> [...].
        raw_score = self.proj_out(features).squeeze(-1)
        return F.softplus(raw_score)
class NEGGate(nn.Module):
    """NEG-Gate: entropy-triggered top-k logit masking.

    For every row whose predicted entropy exceeds ``threshold``, all logits
    outside that row's top-k are replaced with ``-inf``. Rows at or below the
    threshold are returned unchanged.

    NOTE(review): top-k masking always keeps each row's largest logit, so on
    its own it leaves a plain argmax unchanged. Any effect under greedy
    decoding must come from processing applied to the logits afterwards.
    Confirm against the generation setup.

    Args:
        init_threshold: initial value of the learnable ``threshold`` parameter.
        top_k: number of logits kept in a gated row. It is clamped to the
            vocabulary size at call time.
    """

    def __init__(self, init_threshold: float = 1.175, top_k: int = 20):
        super().__init__()
        # "threshold" is the state-dict key read back by attach_neg().
        self.threshold = nn.Parameter(torch.tensor(init_threshold))
        self.top_k = top_k

    def forward(self, logits, predicted_entropy):
        """Mask non-top-k logits in rows whose entropy exceeds the threshold.

        Args:
            logits: [B, V] scores.
            predicted_entropy: [B] entropy estimates, one per row.

        Returns:
            ``logits`` itself when no row is gated. Otherwise a new [B, V]
            tensor where gated rows keep only their top-k values (the rest
            become ``-inf``) and other rows are copied unchanged.
        """
        # [B] -> [B, 1] so the per-row decision broadcasts across the vocab.
        activate = (predicted_entropy > self.threshold).unsqueeze(-1)
        if not bool(activate.any()):
            return logits
        # Clamp so a vocabulary smaller than top_k keeps every logit instead
        # of making topk() raise.
        k = min(self.top_k, logits.size(-1))
        top_k_vals, top_k_idx = logits.topk(k, dim=-1)
        masked = torch.full_like(logits, float("-inf"))
        masked.scatter_(-1, top_k_idx, top_k_vals)
        # Select per row instead of blending arithmetically: the blend
        # logits*(1-a) + masked*a computes -inf*0 = NaN for ungated rows
        # (and for any -inf already present in logits).
        return torch.where(activate, masked, logits)
def _find_neg_file(neg_path_or_repo, hf_token=None):
    """Return a local path to neg_modules.safetensors (local dir first, then the Hub)."""
    if os.path.isdir(neg_path_or_repo):
        candidate = os.path.join(neg_path_or_repo, "neg_modules.safetensors")
        if os.path.exists(candidate):
            return candidate
    try:
        from huggingface_hub import hf_hub_download
        return hf_hub_download(
            repo_id=neg_path_or_repo,
            filename="neg_modules.safetensors",
            token=hf_token or os.environ.get("HF_TOKEN"),
        )
    except Exception as e:
        # Any failure (missing package, bad repo id, auth, network) is
        # reported as FileNotFoundError; the original cause stays chained.
        raise FileNotFoundError(
            f"Cannot locate neg_modules.safetensors at {neg_path_or_repo}: {e}"
        ) from e


def _resolve_hidden_size(config):
    """Read ``hidden_size`` from ``config``, falling back to ``config.text_config``."""
    hidden_size = getattr(config, "hidden_size", None)
    if hidden_size is None:
        hidden_size = getattr(getattr(config, "text_config", None), "hidden_size", None)
    if hidden_size is None:
        raise ValueError("Could not determine hidden_size from model config.")
    return hidden_size


def attach_neg(base_model, neg_path_or_repo, hf_token=None):
    """Attach NEG to a loaded base model.

    Args:
        base_model: a HuggingFace AutoModelForCausalLM instance.
        neg_path_or_repo: local directory or HF repo id that contains
            neg_modules.safetensors.
        hf_token: optional HF token for private repos. Falls back to the
            HF_TOKEN environment variable.

    Returns:
        The same model object. NEG-Head and NEG-Gate are registered on it as
        ``neg_head`` / ``neg_gate``, and ``forward`` is wrapped so that the
        logits at the last position are gated on every call.

    Raises:
        FileNotFoundError: neg_modules.safetensors could not be found or
            downloaded.
        ValueError: the model config exposes no ``hidden_size``.
    """
    neg_file = _find_neg_file(neg_path_or_repo, hf_token)
    hidden_size = _resolve_hidden_size(base_model.config)
    device = next(base_model.parameters()).device

    # Tensors are stored with "head." / "gate." prefixes; strip them so they
    # match the module attribute names.
    state = load_file(neg_file)
    head_sd = {k[len("head."):]: v for k, v in state.items() if k.startswith("head.")}
    gate_sd = {k[len("gate."):]: v for k, v in state.items() if k.startswith("gate.")}

    # NEG modules run in float32 regardless of the base model's dtype.
    head = NEGHead(hidden_size).to(device=device, dtype=torch.float32)
    if head_sd:
        head.load_state_dict(head_sd)
    head.eval()

    gate_threshold = gate_sd.get("threshold", torch.tensor(1.175)).item()
    # top_k is not a learnable param; read it from metadata if present.
    top_k = int(state["meta.top_k"].item()) if "meta.top_k" in state else 20
    gate = NEGGate(init_threshold=gate_threshold, top_k=top_k).to(
        device=device, dtype=torch.float32
    )
    if gate_sd:
        gate.load_state_dict(gate_sd)
    gate.eval()

    # Registered as submodules, so they show up in the model's state_dict.
    base_model.neg_head = head
    base_model.neg_gate = gate

    original_forward = base_model.forward

    def forward_with_neg(*args, **kwargs):
        # NEG-Head consumes the final hidden state, so always request it.
        kwargs["output_hidden_states"] = True
        out = original_forward(*args, **kwargs)
        # Tuple outputs (return_dict=False) carry no .hidden_states attribute;
        # pass them through untouched, as when hidden states are absent.
        hidden_states = getattr(out, "hidden_states", None)
        if hidden_states is None:
            return out
        logits = out.logits
        # With device_map="auto" the last layer / lm_head may live on a
        # different device than the NEG modules; move activations to the NEG
        # device and move the result back.
        neg_device = next(base_model.neg_head.parameters()).device
        last_hidden = hidden_states[-1][:, -1].to(device=neg_device, dtype=torch.float32)
        pred_ent = base_model.neg_head(last_hidden)
        last_logits = logits[:, -1].to(device=neg_device, dtype=torch.float32)
        guided = base_model.neg_gate(last_logits, pred_ent)
        # Replace only the last position on a copy of the logits.
        new_logits = logits.clone()
        new_logits[:, -1] = guided.to(device=logits.device, dtype=logits.dtype)
        out.logits = new_logits
        return out

    base_model.forward = forward_with_neg
    base_model._neg_attached = True
    print("[Darwin-NEG] NEG attached successfully.")
    print(f"[Darwin-NEG] threshold = {gate.threshold.item():.4f}")
    print(f"[Darwin-NEG] top_k = {gate.top_k}")
    print(f"[Darwin-NEG] head params: {sum(p.numel() for p in head.parameters()):,}")
    return base_model
def load_darwin_neg(repo_or_path, torch_dtype=torch.bfloat16, device_map="auto",
                    trust_remote_code=True, hf_token=None, **kwargs):
    """Convenience loader: load the base model and attach NEG in one call.

    Args:
        repo_or_path: HF repo id or local directory. It must hold both the
            base model weights and neg_modules.safetensors.
        torch_dtype, device_map, trust_remote_code: forwarded to
            ``AutoModelForCausalLM.from_pretrained``.
        hf_token: optional HF token. Falls back to the HF_TOKEN environment
            variable.
        **kwargs: extra ``from_pretrained`` arguments. ``low_cpu_mem_usage``
            defaults to True but may be overridden here.

    Returns:
        The loaded model with NEG attached (see ``attach_neg``).

    Example:
        from modeling_darwin_neg import load_darwin_neg
        model = load_darwin_neg("FINAL-Bench/Darwin-9B-NEG", hf_token="hf_...")
    """
    from transformers import AutoModelForCausalLM

    token = hf_token or os.environ.get("HF_TOKEN")
    # Default to low-memory loading without clashing with a caller-supplied
    # value. Passing it both explicitly and via **kwargs raised TypeError.
    kwargs.setdefault("low_cpu_mem_usage", True)
    # NOTE(review): trust_remote_code defaults to True, which executes code
    # shipped with the repo. Keep it only for trusted sources.
    base = AutoModelForCausalLM.from_pretrained(
        repo_or_path,
        torch_dtype=torch_dtype,
        device_map=device_map,
        trust_remote_code=trust_remote_code,
        token=token,
        **kwargs,
    )
    return attach_neg(base, repo_or_path, hf_token=token)