from __future__ import annotations

import re
from dataclasses import dataclass
from typing import Sequence

import torch
from torch import nn
from transformers.generation.utils import GenerationMixin
from transformers.modeling_utils import PreTrainedModel
from transformers.utils.generic import ModelOutput

from .config import CircuitGPTConfig
from .gpt import GPT
from .hook_utils import hook_recorder
|
|
@dataclass
class CircuitGPTCausalLMOutput(ModelOutput):
    """Causal LM output that can also carry recorded intermediate activations."""

    loss: torch.Tensor | None = None
    logits: torch.Tensor | None = None
    activations: dict[str, torch.Tensor] | None = None
|
|
def _activations_regex(keys: Sequence[str]) -> str:
    """Build a regex that matches exactly the requested activation keys."""
    escaped = (re.escape(k) for k in keys)
    return "^(" + "|".join(escaped) + ")$"
|
|
|
|
class CircuitGPTPreTrainedModel(PreTrainedModel):
    """Base class wiring `PreTrainedModel` plumbing to the wrapped `GPT`."""

    config_class = CircuitGPTConfig
    base_model_prefix = "circuit_model"
    circuit_model: GPT

    def __init__(self, config: CircuitGPTConfig, *inputs, **kwargs) -> None:
        super().__init__(config, *inputs, **kwargs)

    def get_input_embeddings(self) -> nn.Module:
        return self.circuit_model.transformer.wte

    def set_input_embeddings(self, value: nn.Module) -> None:
        self.circuit_model.transformer.wte = value

    def get_output_embeddings(self) -> nn.Module:
        return self.circuit_model.lm_head

    def set_output_embeddings(self, new_embeddings: nn.Module) -> None:
        self.circuit_model.lm_head = new_embeddings
|
|
class CircuitGPTForCausalLM(CircuitGPTPreTrainedModel, GenerationMixin):
    """
    Hugging Face-compatible wrapper around `circuit_sparsity.gpt.GPT`.

    All computation happens inside the wrapped module, so output parity with
    the original implementation is guaranteed.
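
    Example (illustrative sketch; assumes an existing ``GPT`` instance ``gpt``,
    a tokenized ``input_ids`` tensor, and a hook key that the wrapped model
    actually exposes; the key below is hypothetical)::

        model = CircuitGPTForCausalLM.from_circuit_model(gpt)
        out = model(input_ids, output_activations=["transformer.h.0.mlp"])
        out.logits       # logits computed by the wrapped GPT
        out.activations  # {"transformer.h.0.mlp": <recorded tensor>}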
| """ |
|
|
    def __init__(self, config: CircuitGPTConfig, circuit_model: GPT | None = None) -> None:
        super().__init__(config)

        if circuit_model is None:
            # Build a fresh GPT from the config and run the standard HF post-init.
            self.circuit_model = GPT(config.to_circuit_config())
            self.post_init()
        else:
            # Adopt an existing (possibly trained) GPT without re-initialising it.
            self.circuit_model = circuit_model
| |
    @classmethod
    def from_circuit_model(cls, circuit_model: GPT) -> "CircuitGPTForCausalLM":
        """Wrap an existing `GPT` instance, deriving the HF config from its config."""
        config = CircuitGPTConfig.from_circuit_config(circuit_model.config)
        return cls(config, circuit_model=circuit_model)
| |
    def forward(
        self,
        input_ids: torch.Tensor,
        labels: torch.LongTensor | None = None,
        output_activations: Sequence[str] | None = None,
        return_dict: bool | None = None,
        use_cache: bool | None = None,
        output_attentions: bool | None = None,
        output_hidden_states: bool | None = None,
        **kwargs,
    ) -> CircuitGPTCausalLMOutput:
        # `return_dict`, `use_cache`, `output_attentions` and `output_hidden_states`
        # are accepted for Hugging Face API compatibility but ignored; any other
        # keyword argument that is actually set is rejected outright.
        remaining_kwargs = {k: v for k, v in kwargs.items() if v is not None}
        if remaining_kwargs:
            unsupported = ", ".join(remaining_kwargs.keys())
            raise ValueError(f"Unsupported arguments for CircuitGPTForCausalLM: {unsupported}")

        if input_ids.size(-1) > self.config.block_size:
            raise ValueError(
                f"Sequence length {input_ids.size(-1)} exceeds block size {self.config.block_size}"
            )

        if output_activations:
            # Run the wrapped GPT under a hook recorder that captures exactly the
            # requested activation keys.
            regex = _activations_regex(output_activations)
            with hook_recorder(regex=regex) as recorded:
                logits, loss, _ = self.circuit_model(input_ids, targets=labels)
            activations = {key: recorded[key] for key in output_activations if key in recorded}
        else:
            activations = None
            logits, loss, _ = self.circuit_model(input_ids, targets=labels)

        # Only report a loss when labels were actually provided.
        if labels is None:
            loss = None

        return CircuitGPTCausalLMOutput(
            loss=loss,
            logits=logits,
            activations=activations,
        )
| |
    def prepare_inputs_for_generation(self, input_ids: torch.Tensor, **kwargs):
        # No KV cache is used, so generation re-feeds the full context each step,
        # truncated to the last `block_size` tokens.
        if input_ids.size(-1) > self.config.block_size:
            input_ids = input_ids[:, -self.config.block_size :]
        return {"input_ids": input_ids}
|
    def _reorder_cache(self, past, beam_idx):
        # No past key/value cache is produced, so there is nothing to reorder.
        return past
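
    # Illustrative note: `generate()` from `GenerationMixin` also works with this
    # wrapper. A sketch, assuming a trained `GPT` instance `gpt` and a prompt
    # tensor `input_ids` of shape (batch, seq):
    #
    #     model = CircuitGPTForCausalLM.from_circuit_model(gpt)
    #     tokens = model.generate(input_ids, max_new_tokens=32, do_sample=False)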
|
|