connaaa commited on
Commit
378744c
·
verified ·
1 Parent(s): 8a1fc85

Phase 1 release: InterpGPT matched-pair checkpoint

Browse files
.gitattributes CHANGED
@@ -1,35 +1,2 @@
1
- *.7z filter=lfs diff=lfs merge=lfs -text
2
- *.arrow filter=lfs diff=lfs merge=lfs -text
3
- *.bin filter=lfs diff=lfs merge=lfs -text
4
- *.bz2 filter=lfs diff=lfs merge=lfs -text
5
- *.ckpt filter=lfs diff=lfs merge=lfs -text
6
- *.ftz filter=lfs diff=lfs merge=lfs -text
7
- *.gz filter=lfs diff=lfs merge=lfs -text
8
- *.h5 filter=lfs diff=lfs merge=lfs -text
9
- *.joblib filter=lfs diff=lfs merge=lfs -text
10
- *.lfs.* filter=lfs diff=lfs merge=lfs -text
11
- *.mlmodel filter=lfs diff=lfs merge=lfs -text
12
- *.model filter=lfs diff=lfs merge=lfs -text
13
- *.msgpack filter=lfs diff=lfs merge=lfs -text
14
- *.npy filter=lfs diff=lfs merge=lfs -text
15
- *.npz filter=lfs diff=lfs merge=lfs -text
16
- *.onnx filter=lfs diff=lfs merge=lfs -text
17
- *.ot filter=lfs diff=lfs merge=lfs -text
18
- *.parquet filter=lfs diff=lfs merge=lfs -text
19
- *.pb filter=lfs diff=lfs merge=lfs -text
20
- *.pickle filter=lfs diff=lfs merge=lfs -text
21
- *.pkl filter=lfs diff=lfs merge=lfs -text
22
- *.pt filter=lfs diff=lfs merge=lfs -text
23
- *.pth filter=lfs diff=lfs merge=lfs -text
24
- *.rar filter=lfs diff=lfs merge=lfs -text
25
  *.safetensors filter=lfs diff=lfs merge=lfs -text
26
- saved_model/**/* filter=lfs diff=lfs merge=lfs -text
27
- *.tar.* filter=lfs diff=lfs merge=lfs -text
28
- *.tar filter=lfs diff=lfs merge=lfs -text
29
- *.tflite filter=lfs diff=lfs merge=lfs -text
30
- *.tgz filter=lfs diff=lfs merge=lfs -text
31
- *.wasm filter=lfs diff=lfs merge=lfs -text
32
- *.xz filter=lfs diff=lfs merge=lfs -text
33
- *.zip filter=lfs diff=lfs merge=lfs -text
34
- *.zst filter=lfs diff=lfs merge=lfs -text
35
- *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  *.safetensors filter=lfs diff=lfs merge=lfs -text
2
+ *.pt filter=lfs diff=lfs merge=lfs -text
 
 
 
 
 
 
 
 
 
README.md ADDED
@@ -0,0 +1,152 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ license: mit
3
+ library_name: transformers
4
+ tags:
5
+ - interpretability
6
+ - mechanistic-interpretability
7
+ - task-decomposition
8
+ - small-language-model
9
+ - transformer-lens
10
+ pipeline_tag: text-generation
11
+ ---
12
+
13
+ # InterpGPT — Standard Model (23M)
14
+
15
+ Part of the **InterpGPT** matched-pair release. This is the **standard** model;
16
+ its counterpart is [`connaaa/interpgpt-adhd-23M`](https://huggingface.co/connaaa/interpgpt-adhd-23M).
17
+ Both models share identical architecture and training recipe; only the training
18
+ data distribution differs.
19
+
20
+ | | Value |
21
+ |---|---|
22
+ | Parameters | 23,471,104 |
23
+ | Layers | 6 |
24
+ | Heads | 8 |
25
+ | d_model | 512 |
26
+ | d_head | 64 |
27
+ | d_mlp (SwiGLU) | 1408 |
28
+ | Vocab | 8192 (custom BPE) |
29
+ | Context length | 512 |
30
+ | Norm | RMSNorm (ε = 1e-6) |
31
+ | Position | RoPE (half-half, base 10,000) |
32
+ | Activation | SwiGLU |
33
+ | Biases | none |
34
+ | Tied input/output embeddings | yes |
35
+ | Training tokens | ~25k steps on task-decomposition corpus |
36
+
37
+ ## What is this model for?
38
+
39
+ Given a task prompt, the model writes a step-by-step decomposition. The
40
+ **standard** variant was trained on normal task decompositions (tasks → subtasks
41
+ in straightforward order). The **ADHD** counterpart was trained on decompositions
42
+ with smaller steps and interleaved micro-regulation actions (e.g. "sip water",
43
+ "deep breath", "quick stretch").
44
+
45
+ The pair is the subject of a mechanistic-interpretability study.
46
+ Phase 1 headline findings:
47
+
48
+ - **Structural head-position swap.** A step-layout-broadcast head lives at
49
+ **L3H0** in the standard model and at **L3H5** in the ADHD model.
50
+ Cross-model per-position attention profile cosine similarity is **0.997**
51
+ at the matched (different-index) pair vs a same-index baseline of **0.66**.
52
+ - **Block-2 content circuit.** P(regulation token) at step-onset positions jumps
53
+ 17× between layer 1 and layer 2 in the ADHD model (0.014 → 0.251); the
54
+ standard model never crosses 1% at any layer.
55
+ - **High-specificity null-steering SAE feature.** See the companion SAE repo
56
+ [`connaaa/interpgpt-sae-phase5`](https://huggingface.co/connaaa/interpgpt-sae-phase5).
57
+
58
+ ## Input format
59
+
60
+ ```
61
+ <|task|>Clean the kitchen<|steps|>Step 1 text<|sep|>Step 2 text<|sep|>...<|end|>
62
+ ```
63
+
64
+ ## Loading
65
+
66
+ ### HuggingFace Transformers (custom code)
67
+
68
+ ```python
69
+ from transformers import AutoModel, AutoTokenizer
70
+ model = AutoModel.from_pretrained(
71
+ "connaaa/interpgpt-standard-23M", trust_remote_code=True
72
+ )
73
+ tokenizer = AutoTokenizer.from_pretrained(
74
+ "connaaa/interpgpt-standard-23M"
75
+ )
76
+ ```
77
+
78
+ ### TransformerLens (recommended for interpretability)
79
+
80
+ The repo ships a TransformerLens-compatible bundle at `hooked_transformer.pt`:
81
+
82
+ ```python
83
+ from huggingface_hub import hf_hub_download
84
+ from transformer_lens import HookedTransformer, HookedTransformerConfig
85
+ import torch
86
+
87
+ path = hf_hub_download(
88
+ "connaaa/interpgpt-standard-23M", "hooked_transformer.pt"
89
+ )
90
+ blob = torch.load(path, map_location="cpu", weights_only=False)
91
+ cfg_keep = {
92
+ k: v for k, v in blob["config"].items()
93
+ if k in HookedTransformerConfig.__dataclass_fields__ and not (
94
+ isinstance(v, str) and v.startswith("torch.")
95
+ )
96
+ }
97
+ cfg = HookedTransformerConfig(**cfg_keep)
98
+ model = HookedTransformer(cfg)
99
+ model.load_state_dict(blob["model_state_dict"])
100
+ model.eval()
101
+ ```
102
+
103
+ ### Raw PyTorch / original TaskGPT class
104
+
105
+ ```python
106
+ # Pairs with gpt_model.py from https://github.com/cwklurks/interpgpt
107
+ from huggingface_hub import hf_hub_download
108
+ from gpt_model import GPTConfig, TaskGPT
109
+ import torch
110
+
111
+ path = hf_hub_download(
112
+ "connaaa/interpgpt-standard-23M", "pytorch_model.pt"
113
+ )
114
+ blob = torch.load(path, map_location="cpu", weights_only=False)
115
+ model = TaskGPT(GPTConfig(**blob["config"]))
116
+ model.load_state_dict(blob["model_state_dict"])
117
+ ```
118
+
119
+ ## Reproduce the head-swap finding
120
+
121
+ Open the companion Colab:
122
+ **`notebooks/InterpGPT_HeadSwap.ipynb`** at
123
+ [github.com/cwklurks/interpgpt](https://github.com/cwklurks/interpgpt).
124
+ End-to-end run on Colab free tier reproduces the 0.997 vs 0.66 comparison
125
+ in under 15 minutes.
126
+
127
+ ## Training data
128
+
129
+ Custom task-decomposition corpus, two variants (standard vs ADHD) generated
130
+ with the same task pool. Detailed dataset notes + generation scripts live in
131
+ the main repo (`preprocess.py`, `merge_data.py`, `rebuild_data.py`,
132
+ `fix_adhd_data.py`, `shorten_adhd_steps.py`).
133
+
134
+ ## License
135
+
136
+ MIT.
137
+
138
+ ## Intended use
139
+
140
+ Interpretability research. The model is intentionally small and
141
+ domain-specific; **not** intended as a general-purpose chatbot.
142
+
143
+ ## Citation
144
+
145
+ ```bibtex
146
+ @misc{interpgpt2026,
147
+ title = {{InterpGPT}: A matched-pair interpretability study of task-decomposition models},
148
+ author = {Klann, Connor},
149
+ year = {2026},
150
+ url = {https://github.com/cwklurks/interpgpt}
151
+ }
152
+ ```
config.json ADDED
@@ -0,0 +1,78 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "vocab_size": 8197,
3
+ "max_seq_len": 512,
4
+ "n_layers": 6,
5
+ "n_heads": 8,
6
+ "d_model": 512,
7
+ "d_ff": 2048,
8
+ "dropout": 0.25,
9
+ "pad_id": 8196,
10
+ "bias": false,
11
+ "variant": "standard",
12
+ "return_dict": true,
13
+ "output_hidden_states": false,
14
+ "torchscript": false,
15
+ "dtype": null,
16
+ "pruned_heads": {},
17
+ "tie_word_embeddings": true,
18
+ "chunk_size_feed_forward": 0,
19
+ "is_encoder_decoder": false,
20
+ "is_decoder": false,
21
+ "cross_attention_hidden_size": null,
22
+ "add_cross_attention": false,
23
+ "tie_encoder_decoder": false,
24
+ "architectures": null,
25
+ "finetuning_task": null,
26
+ "id2label": {
27
+ "0": "LABEL_0",
28
+ "1": "LABEL_1"
29
+ },
30
+ "label2id": {
31
+ "LABEL_0": 0,
32
+ "LABEL_1": 1
33
+ },
34
+ "task_specific_params": null,
35
+ "problem_type": null,
36
+ "tokenizer_class": null,
37
+ "prefix": null,
38
+ "bos_token_id": null,
39
+ "pad_token_id": 8196,
40
+ "eos_token_id": null,
41
+ "sep_token_id": null,
42
+ "decoder_start_token_id": null,
43
+ "max_length": 20,
44
+ "min_length": 0,
45
+ "do_sample": false,
46
+ "early_stopping": false,
47
+ "num_beams": 1,
48
+ "temperature": 1.0,
49
+ "top_k": 50,
50
+ "top_p": 1.0,
51
+ "typical_p": 1.0,
52
+ "repetition_penalty": 1.0,
53
+ "length_penalty": 1.0,
54
+ "no_repeat_ngram_size": 0,
55
+ "encoder_no_repeat_ngram_size": 0,
56
+ "bad_words_ids": null,
57
+ "num_return_sequences": 1,
58
+ "output_scores": false,
59
+ "return_dict_in_generate": false,
60
+ "forced_bos_token_id": null,
61
+ "forced_eos_token_id": null,
62
+ "remove_invalid_values": false,
63
+ "exponential_decay_length_penalty": null,
64
+ "suppress_tokens": null,
65
+ "begin_suppress_tokens": null,
66
+ "num_beam_groups": 1,
67
+ "diversity_penalty": 0.0,
68
+ "_name_or_path": "",
69
+ "transformers_version": "4.57.6",
70
+ "auto_map": {
71
+ "AutoConfig": "configuration_interpgpt.InterpGPTConfig",
72
+ "AutoModel": "modeling_interpgpt.InterpGPTModel"
73
+ },
74
+ "tf_legacy_loss": false,
75
+ "use_bfloat16": false,
76
+ "model_type": "interpgpt",
77
+ "output_attentions": false
78
+ }
configuration_interpgpt.py ADDED
@@ -0,0 +1,37 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ HuggingFace PretrainedConfig for InterpGPT / TaskGPT.
3
+
4
+ Mirrors gpt_model.GPTConfig but subclasses transformers.PretrainedConfig
5
+ so `AutoConfig` / `AutoModel.from_pretrained(..., trust_remote_code=True)` work.
6
+ """
7
+ from transformers import PretrainedConfig
8
+
9
+
10
+ class InterpGPTConfig(PretrainedConfig):
11
+ model_type = "interpgpt"
12
+
13
+ def __init__(
14
+ self,
15
+ vocab_size: int = 8192,
16
+ max_seq_len: int = 512,
17
+ n_layers: int = 6,
18
+ n_heads: int = 8,
19
+ d_model: int = 512,
20
+ d_ff: int = 2048,
21
+ dropout: float = 0.1,
22
+ pad_id: int = 0,
23
+ bias: bool = False,
24
+ variant: str = "standard",
25
+ **kwargs,
26
+ ):
27
+ self.vocab_size = vocab_size
28
+ self.max_seq_len = max_seq_len
29
+ self.n_layers = n_layers
30
+ self.n_heads = n_heads
31
+ self.d_model = d_model
32
+ self.d_ff = d_ff
33
+ self.dropout = dropout
34
+ self.pad_id = pad_id
35
+ self.bias = bias
36
+ self.variant = variant
37
+ super().__init__(pad_token_id=pad_id, **kwargs)
hooked_transformer.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:e9dc5f2f4adf6c3faad817abe5e41853e177b96891646438a588bf0e6b81a106
3
+ size 113993831
model.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:2cdcf5277237808f7ed6e4ae712d58dbef6b8ed0ba113a789c26dcdc4743067f
3
+ size 100969696
modeling_interpgpt.py ADDED
@@ -0,0 +1,194 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ HuggingFace PreTrainedModel wrapper for InterpGPT / TaskGPT.
3
+
4
+ Weights map 1:1 to the original gpt_model.TaskGPT state dict, so the same
5
+ .pt checkpoints produced during Phase 1 load here without remapping.
6
+
7
+ Usage (after upload):
8
+ from transformers import AutoModel, AutoTokenizer
9
+ model = AutoModel.from_pretrained("connaaa/interpgpt-standard-23M",
10
+ trust_remote_code=True)
11
+ # Or for the analysis pipeline:
12
+ from transformer_lens import HookedTransformer
13
+ hooked = HookedTransformer.from_pretrained("connaaa/interpgpt-standard-23M",
14
+ hf_model=model,
15
+ ...)
16
+ """
17
+ import math
18
+ import torch
19
+ import torch.nn as nn
20
+ import torch.nn.functional as F
21
+ from transformers import PreTrainedModel
22
+
23
+ from .configuration_interpgpt import InterpGPTConfig
24
+
25
+
26
+ class RMSNorm(nn.Module):
27
+ def __init__(self, d_model: int, eps: float = 1e-6):
28
+ super().__init__()
29
+ self.weight = nn.Parameter(torch.ones(d_model))
30
+ self.eps = eps
31
+
32
+ def forward(self, x):
33
+ norm = torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
34
+ return x * norm * self.weight
35
+
36
+
37
+ class RotaryPositionalEncoding(nn.Module):
38
+ def __init__(self, d_model: int, max_seq_len: int = 512, base: float = 10000.0):
39
+ super().__init__()
40
+ assert d_model % 2 == 0
41
+ inv_freq = 1.0 / (base ** (torch.arange(0, d_model, 2).float() / d_model))
42
+ self.register_buffer("inv_freq", inv_freq)
43
+ t = torch.arange(max_seq_len, dtype=torch.float)
44
+ freqs = torch.einsum("i,j->ij", t, inv_freq)
45
+ self.register_buffer("cos_cached", freqs.cos())
46
+ self.register_buffer("sin_cached", freqs.sin())
47
+
48
+ def forward(self, seq_len: int):
49
+ return self.cos_cached[:seq_len], self.sin_cached[:seq_len]
50
+
51
+
52
+ def apply_rotary_emb(x, cos, sin):
53
+ d_half = x.shape[-1] // 2
54
+ x1, x2 = x[..., :d_half], x[..., d_half:]
55
+ cos = cos[: x.shape[2]].unsqueeze(0).unsqueeze(0)
56
+ sin = sin[: x.shape[2]].unsqueeze(0).unsqueeze(0)
57
+ out1 = x1 * cos - x2 * sin
58
+ out2 = x2 * cos + x1 * sin
59
+ return torch.cat([out1, out2], dim=-1)
60
+
61
+
62
+ class CausalSelfAttention(nn.Module):
63
+ def __init__(self, config: InterpGPTConfig):
64
+ super().__init__()
65
+ assert config.d_model % config.n_heads == 0
66
+ self.n_heads = config.n_heads
67
+ self.head_dim = config.d_model // config.n_heads
68
+ self.qkv = nn.Linear(config.d_model, 3 * config.d_model, bias=config.bias)
69
+ self.out_proj = nn.Linear(config.d_model, config.d_model, bias=config.bias)
70
+ self.attn_dropout = nn.Dropout(config.dropout)
71
+ self.resid_dropout = nn.Dropout(config.dropout)
72
+ self.rope = RotaryPositionalEncoding(self.head_dim, config.max_seq_len)
73
+ mask = torch.tril(torch.ones(config.max_seq_len, config.max_seq_len))
74
+ self.register_buffer("causal_mask", mask.view(1, 1, config.max_seq_len, config.max_seq_len))
75
+
76
+ def forward(self, x, kv_cache=None):
77
+ B, T, D = x.shape
78
+ qkv = self.qkv(x)
79
+ q, k, v = qkv.chunk(3, dim=-1)
80
+ q = q.view(B, T, self.n_heads, self.head_dim).transpose(1, 2)
81
+ k = k.view(B, T, self.n_heads, self.head_dim).transpose(1, 2)
82
+ v = v.view(B, T, self.n_heads, self.head_dim).transpose(1, 2)
83
+ cos, sin = self.rope(T)
84
+ q = apply_rotary_emb(q, cos, sin)
85
+ k = apply_rotary_emb(k, cos, sin)
86
+ if kv_cache is not None:
87
+ if "k" in kv_cache:
88
+ k = torch.cat([kv_cache["k"], k], dim=2)
89
+ v = torch.cat([kv_cache["v"], v], dim=2)
90
+ kv_cache["k"] = k
91
+ kv_cache["v"] = v
92
+ if hasattr(F, "scaled_dot_product_attention") and kv_cache is None:
93
+ out = F.scaled_dot_product_attention(
94
+ q, k, v,
95
+ attn_mask=None,
96
+ dropout_p=self.attn_dropout.p if self.training else 0.0,
97
+ is_causal=True,
98
+ )
99
+ else:
100
+ scale = 1.0 / math.sqrt(self.head_dim)
101
+ attn = torch.matmul(q, k.transpose(-2, -1)) * scale
102
+ T_k = k.size(2)
103
+ causal = self.causal_mask[:, :, T_k - T : T_k, :T_k]
104
+ attn = attn.masked_fill(causal == 0, float("-inf"))
105
+ attn = F.softmax(attn, dim=-1)
106
+ attn = self.attn_dropout(attn)
107
+ out = torch.matmul(attn, v)
108
+ out = out.transpose(1, 2).contiguous().view(B, T, D)
109
+ return self.resid_dropout(self.out_proj(out))
110
+
111
+
112
+ class FeedForward(nn.Module):
113
+ def __init__(self, config: InterpGPTConfig):
114
+ super().__init__()
115
+ hidden = int(2 * config.d_ff / 3)
116
+ hidden = 64 * ((hidden + 63) // 64)
117
+ self.gate_proj = nn.Linear(config.d_model, hidden, bias=config.bias)
118
+ self.up_proj = nn.Linear(config.d_model, hidden, bias=config.bias)
119
+ self.down_proj = nn.Linear(hidden, config.d_model, bias=config.bias)
120
+ self.dropout = nn.Dropout(config.dropout)
121
+
122
+ def forward(self, x):
123
+ return self.dropout(self.down_proj(F.silu(self.gate_proj(x)) * self.up_proj(x)))
124
+
125
+
126
+ class TransformerBlock(nn.Module):
127
+ def __init__(self, config: InterpGPTConfig):
128
+ super().__init__()
129
+ self.ln1 = RMSNorm(config.d_model)
130
+ self.attn = CausalSelfAttention(config)
131
+ self.ln2 = RMSNorm(config.d_model)
132
+ self.ffn = FeedForward(config)
133
+
134
+ def forward(self, x, kv_cache=None):
135
+ x = x + self.attn(self.ln1(x), kv_cache)
136
+ x = x + self.ffn(self.ln2(x))
137
+ return x
138
+
139
+
140
+ class InterpGPTModel(PreTrainedModel):
141
+ """
142
+ HF-wrapped InterpGPT / TaskGPT. State dict parameter names match the
143
+ original gpt_model.TaskGPT exactly so Phase 1 .pt checkpoints load
144
+ via state_dict without remapping.
145
+ """
146
+ config_class = InterpGPTConfig
147
+ base_model_prefix = "interpgpt"
148
+ supports_gradient_checkpointing = False
149
+
150
+ def __init__(self, config: InterpGPTConfig):
151
+ super().__init__(config)
152
+ self.config = config
153
+ self.token_embedding = nn.Embedding(config.vocab_size, config.d_model, padding_idx=config.pad_id)
154
+ self.drop = nn.Dropout(config.dropout)
155
+ self.blocks = nn.ModuleList([TransformerBlock(config) for _ in range(config.n_layers)])
156
+ self.ln_final = RMSNorm(config.d_model)
157
+ self.lm_head = nn.Linear(config.d_model, config.vocab_size, bias=False)
158
+ self.lm_head.weight = self.token_embedding.weight
159
+ self.post_init()
160
+
161
+ def _init_weights(self, module):
162
+ if isinstance(module, nn.Linear):
163
+ nn.init.normal_(module.weight, mean=0.0, std=0.02)
164
+ if module.bias is not None:
165
+ nn.init.zeros_(module.bias)
166
+ elif isinstance(module, nn.Embedding):
167
+ nn.init.normal_(module.weight, mean=0.0, std=0.02)
168
+ if module.padding_idx is not None:
169
+ nn.init.zeros_(module.weight[module.padding_idx])
170
+
171
+ def forward(self, input_ids, attention_mask=None, labels=None, loss_mask=None, **kwargs):
172
+ B, T = input_ids.shape
173
+ x = self.drop(self.token_embedding(input_ids))
174
+ for block in self.blocks:
175
+ x = block(x)
176
+ x = self.ln_final(x)
177
+ logits = self.lm_head(x)
178
+ output = {"logits": logits}
179
+ if labels is not None:
180
+ shift_logits = logits[:, :-1].contiguous()
181
+ shift_labels = labels[:, 1:].contiguous()
182
+ loss = F.cross_entropy(
183
+ shift_logits.view(-1, self.config.vocab_size),
184
+ shift_labels.view(-1),
185
+ ignore_index=self.config.pad_id,
186
+ reduction="none",
187
+ ).view(B, T - 1)
188
+ if loss_mask is not None:
189
+ shift_mask = loss_mask[:, 1:].contiguous().float()
190
+ loss = (loss * shift_mask).sum() / shift_mask.sum().clamp(min=1.0)
191
+ else:
192
+ loss = loss.mean()
193
+ output["loss"] = loss
194
+ return output
pytorch_model.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:905dda81b9380851c52374cbb089646c4cd1cd67961ba3eb8a8b9f24f514ca3d
3
+ size 288797081
tokenizer.json ADDED
The diff for this file is too large to render. See raw diff