| """
|
| LM-eval harness wrapper for Circuit/Mirrored transformers.
|
|
|
| Usage:
|
| # Single model
|
| python -m circuits.bench --checkpoint circuits/checkpoints/mirrored/best.pt --gpu 0
|
|
|
| # Compare all architectures
|
| python -m circuits.bench --compare --gpu 0
|
| """
|
|
|
| import torch
|
| import torch.nn as nn
|
| import torch.nn.functional as F
|
| from typing import List
|
| from tqdm import tqdm
|
| from lm_eval.api.model import LM
|
| from lm_eval.api.instance import Instance
|
|
|
| from .config import CircuitConfig
|
| from .model import CircuitTransformer
|
| from .mirrored import MirroredConfig, MirroredTransformer
|
| from .graft_g2lu import load_g2lu_model
|
| from .layers import build_word_start_table, compute_word_positions
|
| from .data import get_tokenizer
|
|
|
| def _migrate_state_dict(state_dict: dict, model: nn.Module) -> dict:
|
| """Migrate checkpoint state_dict to match current model architecture.
|
|
|
| Handles upgrades like SwiGLU → MirroredSwiGLU (dual_gate_middle).
|
| """
|
| if any(k.startswith("_orig_mod.") for k in state_dict):
|
| state_dict = {k.removeprefix("_orig_mod."): v for k, v in state_dict.items()}
|
|
|
| model_keys = set(model.state_dict().keys())
|
| ckpt_keys = set(state_dict.keys())
|
|
|
| missing = model_keys - ckpt_keys
|
| unexpected = ckpt_keys - model_keys
|
|
|
| print(unexpected)
|
|
|
| if not missing and not unexpected:
|
| return state_dict
|
|
|
| migrated = dict(state_dict)
|
| migrations = []
|
|
|
|
|
| for key in list(unexpected):
|
| if ".ffn.gate_expand.weight" in key:
|
| new_key = key.replace(".ffn.gate_expand.weight", ".ffn.w3.weight")
|
| if new_key in missing:
|
| migrated[new_key] = migrated.pop(key)
|
| missing.discard(new_key)
|
| unexpected.discard(key)
|
| migrations.append(f" {key} → {new_key}")
|
| if ".ffn.gate_compress.weight" in key:
|
| new_key = key.replace(".ffn.gate_compress.weight", ".ffn.w4.weight")
|
| if new_key in missing:
|
| migrated[new_key] = migrated.pop(key)
|
| missing.discard(new_key)
|
| unexpected.discard(key)
|
| migrations.append(f" {key} → {new_key}")
|
|
|
| if migrations:
|
| print(f"State dict migration ({len(migrations)} keys renamed):")
|
| for m in migrations:
|
| print(m)
|
|
|
| still_missing = model_keys - set(migrated.keys())
|
| if still_missing:
|
| print(f" New parameters (freshly initialized): {len(still_missing)}")
|
| for k in sorted(still_missing):
|
| print(f" {k}")
|
|
|
| return migrated
|
|
|
def load_model(checkpoint_path: str, device: str = "cuda"):
    """Load any circuit-family model from a checkpoint with auto-detection.

    The checkpoint's ``model_type`` field selects the architecture:
    ``"graft_g2lu"``, ``"mirrored"``, or anything else → standard
    ``CircuitTransformer``.

    Args:
        checkpoint_path: Path to a ``torch.save``'d training checkpoint.
        device: Device string the model is moved to.

    Returns:
        Tuple of ``(model, config, arch_name, model_type)`` where ``arch_name``
        is a human-readable architecture description.
    """
    # weights_only=False: checkpoints carry config dicts, not just tensors.
    # NOTE(review): this deserializes pickled data — only load trusted files.
    checkpoint = torch.load(checkpoint_path, map_location="cpu", weights_only=False)

    model_type = checkpoint.get("model_type", "standard")
    if model_type == "graft_g2lu":
        # Grafted model has its own loader; it handles device placement itself.
        model = load_g2lu_model(checkpoint_path, device=device)
        model.eval()
        n_layers = len(model.g2lu_mlps)
        arch_name = f"G²LU Graft ({checkpoint['pretrained_name']}, {n_layers}L)"
        config = model.model.config
        return model, config, arch_name, model_type
    elif model_type == "mirrored":
        # Strip the legacy "dual_gate_middle" flag regardless of its value so
        # MirroredConfig.from_dict never sees the stale key (the previous code
        # only removed it when truthy, leaving a `False` entry behind).
        checkpoint["config"].pop("dual_gate_middle", None)
        config = MirroredConfig.from_dict(checkpoint["config"])
        model = MirroredTransformer(config)
        arch_name = f"Mirrored ({model.total_virtual_layers}L)"
    else:
        config = CircuitConfig.from_dict(checkpoint["config"])
        model = CircuitTransformer(config)
        arch_name = f"Standard ({config.num_layers}L)"

    # Rename any legacy parameter keys before loading the weights.
    state_dict = checkpoint["model"]
    state_dict = _migrate_state_dict(state_dict, model)
    model.load_state_dict(state_dict)

    model = model.to(device).eval()
    return model, config, arch_name, model_type
|
|
|
|
|
class CircuitLM(LM):
    """LM-eval wrapper for the Circuit transformer family.

    Implements the three entry points the lm-eval harness requires of an
    ``LM`` subclass: ``loglikelihood``, ``loglikelihood_rolling`` and
    ``generate_until``. Requests are processed one at a time.
    """

    def __init__(
        self,
        checkpoint: str,
        device: str = "cuda",
        batch_size: int = 1,
        compile: bool = False,
    ):
        """Load model, tokenizer, and optional word-position table.

        Args:
            checkpoint: Path to the training checkpoint.
            device: Device to run evaluation on.
            batch_size: Reported batch size (evaluation is per-request).
            compile: If True, wrap the model in ``torch.compile``.
        """
        super().__init__()

        self.model, self.config, self.arch_name, self.model_type = load_model(
            checkpoint, device
        )

        # Keep a handle on the uncompiled module: custom methods such as
        # .generate() are not exposed through the torch.compile wrapper.
        self._raw_model = self.model
        if compile:
            self.model = torch.compile(self.model)
            print(" torch.compile: enabled")

        # Re-open the checkpoint only to read the tokenizer name (load_model
        # does not surface it); the tensors are released immediately after.
        _ckpt = torch.load(checkpoint, map_location="cpu", weights_only=False)
        _tok_name = _ckpt.get("tokenizer_name", "gpt2")
        del _ckpt
        self.tokenizer = get_tokenizer(_tok_name)
        if self.tokenizer.pad_token is None:
            self.tokenizer.pad_token = self.tokenizer.eos_token

        self._device = device
        self._batch_size = batch_size

        # Optional word-position RoPE; config may be a dataclass or plain dict.
        self._word_start_table = None
        word_rope_dims = getattr(self.config, "word_rope_dims", 0)
        if word_rope_dims == 0 and isinstance(self.config, dict):
            word_rope_dims = self.config.get("word_rope_dims", 0)
        if word_rope_dims > 0:
            self._word_start_table = build_word_start_table(
                self.tokenizer, len(self.tokenizer)
            ).to(device)
            print(f" Word-position RoPE: {word_rope_dims} dims")

        n_params = sum(p.numel() for p in self.model.parameters())
        print(f" Architecture: {self.arch_name}")
        print(f" Parameters: {n_params / 1e6:.1f}M")

    @property
    def eot_token_id(self) -> int:
        """End-of-text token id; terminates greedy generation."""
        return self.tokenizer.eos_token_id

    @property
    def max_length(self) -> int:
        """Maximum context length; 512 if the config declares neither field."""
        return getattr(self.config, "max_seq_len", None) or getattr(
            self.config, "max_position_embeddings", 512
        )

    @property
    def max_gen_toks(self) -> int:
        """Default new-token budget for generate_until."""
        return 256

    @property
    def batch_size(self) -> int:
        return self._batch_size

    @property
    def device(self) -> str:
        return self._device

    def tok_encode(self, string: str) -> List[int]:
        """Tokenize *string* without adding special tokens."""
        return self.tokenizer.encode(string, add_special_tokens=False)

    def tok_decode(self, tokens: List[int]) -> str:
        """Decode a token-id list back to text."""
        return self.tokenizer.decode(tokens)

    def _model_call(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Run one forward pass and return logits (bf16 autocast off-CPU)."""
        with torch.inference_mode(), torch.autocast(
            "cuda", dtype=torch.bfloat16, enabled=self._device != "cpu"
        ):
            word_positions = None
            if self._word_start_table is not None:
                word_positions = compute_word_positions(
                    input_ids, self._word_start_table
                )
            output = self.model(
                input_ids, use_cache=False, word_positions=word_positions
            )
            return output["logits"]

    def _loglikelihood_tokens(self, requests, disable_tqdm=False):
        """Score a list of (context_tokens, continuation_tokens) pairs.

        Returns:
            List of ``(total_log_prob, is_greedy)`` tuples; ``is_greedy`` is
            True when the continuation equals the model's argmax decoding.
        """
        results = []
        for context_enc, continuation_enc in requests:
            # Left-truncate the context so context + continuation fits.
            full_enc = context_enc + continuation_enc
            if len(full_enc) > self.max_length:
                excess = len(full_enc) - self.max_length
                context_enc = context_enc[excess:]
                full_enc = context_enc + continuation_enc

            input_ids = torch.tensor(
                [full_enc], dtype=torch.long, device=self._device
            )

            logits = self._model_call(input_ids)

            # Logits at position i predict token i+1, hence the -1 shift.
            ctx_len = len(context_enc)
            cont_logits = logits[:, ctx_len - 1 : -1, :]
            cont_tokens = input_ids[:, ctx_len:]

            log_probs = F.log_softmax(cont_logits, dim=-1)
            token_log_probs = log_probs.gather(
                2, cont_tokens.unsqueeze(-1)
            ).squeeze(-1)

            total_log_prob = token_log_probs.sum().item()
            is_greedy = (cont_logits.argmax(dim=-1) == cont_tokens).all().item()

            results.append((total_log_prob, is_greedy))

        return results

    def loglikelihood(
        self, requests: List[Instance], disable_tqdm: bool = False
    ) -> List[tuple]:
        """Return (log_prob, is_greedy) for each (context, continuation) request."""
        results = []
        for request in tqdm(
            requests, desc="loglikelihood", disable=disable_tqdm
        ):
            context, continuation = request.args

            # Encode context and full text, then take the continuation as the
            # suffix — keeps BPE merges at the boundary consistent.
            context_enc = self.tok_encode(context)
            full_enc = self.tok_encode(context + continuation)
            continuation_enc = full_enc[len(context_enc):]
            if not continuation_enc:
                # A boundary merge swallowed the continuation; encode it alone.
                continuation_enc = self.tok_encode(continuation)
            if not context_enc:
                # lm-eval convention: score an empty context after the EOT
                # token. Also avoids an empty logits slice (ctx_len == 0) in
                # _loglikelihood_tokens, which silently scored 0.0 before.
                context_enc = [self.eot_token_id]
            result = self._loglikelihood_tokens([(context_enc, continuation_enc)])
            results.append(result[0])
        return results

    def loglikelihood_rolling(
        self, requests: List[Instance], disable_tqdm: bool = False
    ) -> List[float]:
        """Sum token log-probs over whole texts in non-overlapping chunks.

        NOTE(review): chunks do not overlap, so the first token of each chunk
        after the first is scored without preceding context — an approximation
        of the true rolling loglikelihood.
        """
        results = []
        for request in tqdm(
            requests, desc="loglikelihood_rolling", disable=disable_tqdm
        ):
            text = request.args[0]
            encoding = self.tok_encode(text)

            total_log_prob = 0.0
            max_len = self.max_length

            for i in range(0, len(encoding), max_len):
                chunk = encoding[i : i + max_len]
                input_ids = torch.tensor(
                    [chunk], dtype=torch.long, device=self._device
                )

                logits = self._model_call(input_ids)
                # Shift: logits at position j predict token j+1.
                shift_logits = logits[:, :-1, :]
                shift_labels = input_ids[:, 1:]

                log_probs = F.log_softmax(shift_logits, dim=-1)
                token_log_probs = log_probs.gather(
                    2, shift_labels.unsqueeze(-1)
                ).squeeze(-1)

                total_log_prob += token_log_probs.sum().item()

            results.append(total_log_prob)
        return results

    def generate_until(
        self, requests: List[Instance], disable_tqdm: bool = False
    ) -> List[str]:
        """Greedy-decode until a stop string, EOT, or the token budget is hit."""
        results = []
        for request in tqdm(
            requests, desc="generate_until", disable=disable_tqdm
        ):
            context = request.args[0]
            # NOTE(review): lm-eval typically passes generation kwargs as
            # request.args[1]; this reads a `kwargs` attribute instead —
            # confirm against the harness version in use.
            gen_kwargs = getattr(request, "kwargs", {}) or {}

            until = gen_kwargs.get("until", [self.tokenizer.eos_token])
            max_gen = gen_kwargs.get("max_gen_toks", self.max_gen_toks)

            context_enc = self.tok_encode(context)

            # Left-truncate the prompt to leave room for max_gen new tokens.
            if len(context_enc) > self.max_length - max_gen:
                context_enc = context_enc[-(self.max_length - max_gen) :]
            input_ids = torch.tensor(
                [context_enc], dtype=torch.long, device=self._device
            )

            if self.model_type == "graft_g2lu":
                # The grafted HF model has its own generate(); call it on the
                # uncompiled module (torch.compile hides custom methods).
                with torch.no_grad():
                    output_ids = self._raw_model.generate(
                        input_ids,
                        max_new_tokens=max_gen,
                        do_sample=False,
                        use_cache=True,
                    )
                generated_text = self.tok_decode(
                    output_ids[0, input_ids.shape[1] :].tolist()
                )
            else:
                # Plain greedy loop without a KV cache: re-runs the full
                # sequence every step.
                generated_ids = input_ids.clone()
                with torch.no_grad():
                    for _ in range(max_gen):
                        # Safety clamp; with the truncation above this should
                        # never trigger.
                        if generated_ids.shape[1] > self.max_length:
                            generated_ids = generated_ids[:, -self.max_length :]

                        logits = self._model_call(generated_ids)
                        next_logits = logits[:, -1, :]
                        next_token = next_logits.argmax(dim=-1, keepdim=True)
                        generated_ids = torch.cat([generated_ids, next_token], dim=1)

                        if next_token.item() == self.eot_token_id:
                            break

                        # Check stop strings against the decoded tail so stops
                        # spanning token boundaries are caught.
                        current_text = self.tok_decode(
                            generated_ids[0, len(context_enc) :].tolist()
                        )
                        if any(s in current_text for s in until):
                            break

                generated_text = self.tok_decode(
                    generated_ids[0, len(context_enc) :].tolist()
                )

            # Trim the output at the first stop sequence that appears.
            for stop in until:
                if stop in generated_text:
                    generated_text = generated_text[: generated_text.index(stop)]

            results.append(generated_text)

        return results
|
|
|