""" Darwin-9B-NEG — Native Entropy Gating enabled model. Helper module to attach NEG (Native Entropy Gating) to a Darwin base model. Provides: - NEGHead : predicts per-token entropy from last hidden state - NEGGate : non-monotonic top-k logit masking (effective in greedy decoding) - attach_neg(model, path_or_repo) : monkey-patches forward to apply NEG See README.md for usage. """ import os import torch import torch.nn as nn import torch.nn.functional as F from safetensors.torch import load_file class NEGHead(nn.Module): """NEG-Head: predicts entropy of next-token distribution. Input: hidden_state [B, H] Output: predicted_entropy [B] (>= 0 via softplus) """ def __init__(self, hidden: int, dropout: float = 0.1): super().__init__() self.proj_down = nn.Linear(hidden, hidden // 4) self.act = nn.GELU() self.dropout = nn.Dropout(dropout) self.proj_out = nn.Linear(hidden // 4, 1) def forward(self, h): x = self.proj_down(h) x = self.act(x) x = self.dropout(x) return F.softplus(self.proj_out(x).squeeze(-1)) class NEGGate(nn.Module): """NEG-Gate: top-k logit masking (non-monotonic). When predicted_entropy > threshold, restrict logits to top-k candidates. This changes argmax (non-monotonic), making NEG effective in greedy decoding. """ def __init__(self, init_threshold: float = 1.175, top_k: int = 20): super().__init__() self.threshold = nn.Parameter(torch.tensor(init_threshold)) self.top_k = top_k def forward(self, logits, predicted_entropy): activate = (predicted_entropy > self.threshold).float().unsqueeze(-1) if activate.sum() == 0: return logits top_k_vals, top_k_idx = logits.topk(self.top_k, dim=-1) masked = torch.full_like(logits, float('-inf')) masked.scatter_(-1, top_k_idx, top_k_vals) return logits * (1 - activate) + masked * activate def attach_neg(base_model, neg_path_or_repo, hf_token=None): """Attach NEG to a loaded base model. Args: base_model: a HuggingFace AutoModelForCausalLM instance neg_path_or_repo: local path or HF repo containing neg_modules.safetensors hf_token: optional HF token (for private repos) Returns: The same model with NEG-Head and NEG-Gate attached and forward() wrapped to apply NEG at each generation step. """ # Find neg_modules.safetensors neg_file = None if os.path.isdir(neg_path_or_repo): candidate = os.path.join(neg_path_or_repo, "neg_modules.safetensors") if os.path.exists(candidate): neg_file = candidate if neg_file is None: try: from huggingface_hub import hf_hub_download neg_file = hf_hub_download( repo_id=neg_path_or_repo, filename="neg_modules.safetensors", token=hf_token or os.environ.get("HF_TOKEN"), ) except Exception as e: raise FileNotFoundError( f"Cannot locate neg_modules.safetensors at {neg_path_or_repo}: {e}" ) # Determine hidden size and device hidden_size = getattr(base_model.config, "hidden_size", None) if hidden_size is None: hidden_size = getattr(getattr(base_model.config, "text_config", None), "hidden_size", None) if hidden_size is None: raise ValueError("Could not determine hidden_size from model config.") device = next(base_model.parameters()).device # Load state dict state = load_file(neg_file) head_sd = {k.replace("head.", "", 1): v for k, v in state.items() if k.startswith("head.")} gate_sd = {k.replace("gate.", "", 1): v for k, v in state.items() if k.startswith("gate.")} # Build and load NEG modules head = NEGHead(hidden_size).to(device=device, dtype=torch.float32) if head_sd: head.load_state_dict(head_sd) head.eval() # Infer gate params from state gate_threshold = gate_sd.get("threshold", torch.tensor(1.175)).item() # top_k is not a learnable param; read from metadata if present, else default 20 top_k = state.get("meta.top_k", torch.tensor(20)).item() if "meta.top_k" in state else 20 gate = NEGGate(init_threshold=gate_threshold, top_k=int(top_k)).to( device=device, dtype=torch.float32 ) if gate_sd: gate.load_state_dict(gate_sd) gate.eval() # Attach base_model.neg_head = head base_model.neg_gate = gate # Wrap forward original_forward = base_model.forward def forward_with_neg(*args, **kwargs): # Force hidden states capture kwargs["output_hidden_states"] = True out = original_forward(*args, **kwargs) hidden_states = out.hidden_states if hidden_states is None: return out last_hidden = hidden_states[-1][:, -1].float() pred_ent = base_model.neg_head(last_hidden) logits = out.logits last_logits = logits[:, -1].float() guided = base_model.neg_gate(last_logits, pred_ent) # Clone and replace last position new_logits = logits.clone() new_logits[:, -1] = guided.to(logits.dtype) out.logits = new_logits return out base_model.forward = forward_with_neg base_model._neg_attached = True print(f"[Darwin-NEG] NEG attached successfully.") print(f"[Darwin-NEG] threshold = {gate.threshold.item():.4f}") print(f"[Darwin-NEG] top_k = {gate.top_k}") print(f"[Darwin-NEG] head params: {sum(p.numel() for p in head.parameters()):,}") return base_model def load_darwin_neg(repo_or_path, torch_dtype=torch.bfloat16, device_map="auto", trust_remote_code=True, hf_token=None, **kwargs): """Convenience loader: loads base model + attaches NEG in one call. Example: from modeling_darwin_neg import load_darwin_neg model = load_darwin_neg("FINAL-Bench/Darwin-9B-NEG", hf_token="hf_...") """ from transformers import AutoModelForCausalLM token = hf_token or os.environ.get("HF_TOKEN") base = AutoModelForCausalLM.from_pretrained( repo_or_path, torch_dtype=torch_dtype, device_map=device_map, trust_remote_code=trust_remote_code, token=token, low_cpu_mem_usage=True, **kwargs, ) return attach_neg(base, repo_or_path, hf_token=token)