# Source: affine-recoded-minimal-code-table-free / model_binary_affine_code_n_layer_32.py
# Hugging Face Hub listing residue: uploader avatar "E6E831728's picture";
# commit message "Upload folder using huggingface_hub"; revision 3a37bc0 (verified).
import math
import torch
import torch.nn as nn
from torch.nn import functional as F
from transformers import PreTrainedModel, PretrainedConfig
from transformers.generation import GenerationMixin
from transformers.modeling_outputs import CausalLMOutput, CausalLMOutputWithCrossAttentions
def gf2_rank(M: torch.Tensor) -> int:
    """
    Compute the rank of a binary matrix over GF(2).

    M: [n, m] tensor; only the lowest bit of every entry is used.
    The input is not modified (elimination runs on a uint8 copy).
    Returns the number of pivots found by Gauss-Jordan elimination.
    """
    work = M.clone().to(torch.uint8) & 1
    n_rows, n_cols = work.shape

    found = 0  # pivots placed so far; also the row the next pivot goes to
    for col in range(n_cols):
        if found == n_rows:
            break

        # First row at or below `found` that has a 1 in this column.
        below = torch.nonzero(work[found:, col]).flatten()
        if below.numel() == 0:
            continue
        pivot = found + int(below[0])

        if pivot != found:
            work[[found, pivot]] = work[[pivot, found]]

        # XOR the pivot row into every other row that has a 1 in this column.
        hits = work[:, col].bool()
        hits[found] = False
        work[hits] ^= work[found]

        found += 1

    return found
def gf2_inverse(A: torch.Tensor) -> torch.Tensor:
    """
    Invert a binary matrix over GF(2) by Gauss-Jordan elimination.

    A: [n, n] tensor; only the lowest bit of every entry is used.
    The input is not modified.

    Returns a uint8 [n, n] tensor X with (A @ X) mod 2 == I.

    Raises:
        ValueError: if A is not a square 2-D matrix, or is singular over GF(2).
        RuntimeError: if the elimination did not reduce the left half to I
            (internal consistency check).
    """
    # A non-square input would otherwise be concatenated with an n x n
    # identity and produce a meaningless slice instead of an error.
    if A.dim() != 2 or A.shape[0] != A.shape[1]:
        raise ValueError(
            f"gf2_inverse expects a square [n, n] matrix, got shape {tuple(A.shape)}."
        )

    A = A.clone().to(torch.uint8) & 1
    n = A.shape[0]
    identity = torch.eye(n, dtype=torch.uint8, device=A.device)
    aug = torch.cat([A, identity], dim=1)  # [A | I], reduced to [I | A^-1]

    # Every column must yield a pivot, so the pivot row index equals the column.
    for col in range(n):
        candidates = torch.nonzero(aug[col:, col]).flatten()
        if candidates.numel() == 0:
            raise ValueError("Matrix is not invertible over GF(2).")
        pivot = col + int(candidates[0])

        if pivot != col:
            aug[[col, pivot]] = aug[[pivot, col]]

        # Clear this column in every other row (XOR is addition in GF(2)).
        others = aug[:, col].bool()
        others[col] = False
        aug[others] ^= aug[col]

    if not torch.equal(aug[:, :n], identity):
        raise RuntimeError("GF(2) inverse construction failed.")
    return aug[:, n:]
def make_random_invertible_binary_matrix(
    code_bits: int,
    seed: int = 0,
    min_row_weight: int = 4,
    min_col_weight: int = 4,
    device: str = "cpu",
):
    """
    Sample a random invertible matrix A in GL(code_bits, 2) and a random
    shift b in {0,1}^{code_bits} by rejection sampling.

    Args:
        code_bits: side length of A and length of b.
        seed: seed for the private torch.Generator, so the result is
            reproducible for a given seed and device.
        min_row_weight / min_col_weight: optional lower bounds on the number
            of ones in every row / column, to avoid trivial near-permutation
            matrices. Pass None to disable a bound.
        device: device on which the generator and the tensors live.

    Returns:
        (A, b): uint8 tensors of shape [code_bits, code_bits] and [code_bits].

    Raises:
        ValueError: if a weight bound can provably never be met, which would
            otherwise make the rejection loop run forever.
    """

    def _floor_is_impossible(floor):
        # An empty matrix satisfies any bound vacuously.
        if floor is None or code_bits < 1:
            return False
        if floor > code_bits:
            return True
        # Weight == code_bits forces the all-ones matrix, singular for n > 1.
        return floor == code_bits and code_bits > 1

    for label, floor in (
        ("min_row_weight", min_row_weight),
        ("min_col_weight", min_col_weight),
    ):
        if _floor_is_impossible(floor):
            raise ValueError(
                f"{label}={floor} cannot be satisfied by an invertible "
                f"{code_bits}x{code_bits} matrix over GF(2)."
            )

    g = torch.Generator(device=device)
    g.manual_seed(seed)

    while True:
        # Exactly one draw per attempt, so the accepted matrix for a given
        # seed does not depend on the order of the checks below.
        A = torch.randint(
            0, 2, (code_bits, code_bits),
            generator=g, dtype=torch.uint8, device=device,
        )

        # Cheap weight checks first; the rank computation is the slow part.
        if min_row_weight is not None and not torch.all(A.sum(dim=1) >= min_row_weight):
            continue
        if min_col_weight is not None and not torch.all(A.sum(dim=0) >= min_col_weight):
            continue
        if gf2_rank(A) != code_bits:
            continue

        b = torch.randint(
            0, 2, (code_bits,),
            generator=g, dtype=torch.uint8, device=device,
        )
        return A, b
class BVVConfig(PretrainedConfig):
    """
    Configuration for the table-free binary affine-code causal LM.

    The constructor checks that the vocabulary is exactly 2**code_bits, that
    d_model is a multiple of both code_bits and n_head, and that the per-head
    width is even (needed by the rotary embedding). `n_embed` and `pad_id`
    are legacy aliases kept for older scripts.
    """

    model_type = "model_binary_affine_code_n_layer_32"

    def __init__(
        self,
        vocab_size=65536,
        code_bits=None,
        n_embed=16,  # backward-compatible alias
        d_model=1024,
        n_head=32,
        n_layer=32,
        block_size=1024,
        dropout=0.00,
        layer_norm_eps=1e-5,
        initializer_range=0.02,
        pad_token_id=57344,
        pad_id=57344,
        bos_token_id=None,
        eos_token_id=None,
        tie_word_embeddings=False,
        use_cache=False,
        # affine code params
        code_seed=12345,
        code_matrix=None,  # optional explicit A
        code_shift=None,  # optional explicit b
        min_row_weight=4,
        min_col_weight=4,
        zero_pad_code=True,
        **kwargs,
    ):
        # The legacy pad_id is consulted only when pad_token_id is passed
        # explicitly as None; 57344 is the last-resort default.
        if pad_token_id is None:
            pad_token_id = pad_id if pad_id is not None else 57344

        super().__init__(
            pad_token_id=pad_token_id,
            bos_token_id=bos_token_id,
            eos_token_id=eos_token_id,
            tie_word_embeddings=tie_word_embeddings,
            use_cache=use_cache,
            **kwargs,
        )

        # code_bits takes precedence; n_embed is only the fallback.
        if code_bits is None:
            code_bits = n_embed

        # ---- geometry checks (order preserved: vocab, code tiling, heads) ----
        if (1 << code_bits) != vocab_size:
            raise ValueError(
                f"For the exact minimal-code experiment require "
                f"vocab_size == 2**code_bits, got vocab_size={vocab_size}, code_bits={code_bits}."
            )
        if d_model % code_bits:
            raise ValueError(f"d_model ({d_model}) must be divisible by code_bits ({code_bits})")
        if d_model % n_head:
            raise ValueError(f"d_model ({d_model}) must be divisible by n_head ({n_head})")
        head_dim = d_model // n_head
        if head_dim % 2:
            raise ValueError("head_dim must be even for rotary embeddings")

        # ---- sequence / vocabulary ----
        self.vocab_size = vocab_size
        self.block_size = block_size
        self.max_position_embeddings = block_size

        # ---- code width (n_embed mirrors code_bits for old scripts) ----
        self.code_bits = code_bits
        self.n_embed = code_bits

        # ---- transformer shape and regularisation ----
        self.d_model = d_model
        self.n_head = n_head
        self.n_layer = n_layer
        self.dropout = dropout
        self.layer_norm_eps = layer_norm_eps
        self.initializer_range = initializer_range

        # Number of times the code_bits-wide code is tiled to fill d_model.
        self.scale = d_model // code_bits

        # ---- affine code parameters ----
        self.code_seed = code_seed
        self.code_matrix = code_matrix
        self.code_shift = code_shift
        self.min_row_weight = min_row_weight
        self.min_col_weight = min_col_weight
        self.zero_pad_code = zero_pad_code

        # ---- backward compatibility ----
        self.pad_id = pad_token_id
class BinaryAffineCodeInput(nn.Module):
    """
    Table-free token input:
    token id -> code_bits-wide binary code (LSB first) -> affine GF(2) mixing
    (c = A @ bits + b mod 2) -> code tiled `scale` times to reach d_model.

    No trainable parameters; A and b are stored as persistent uint8 buffers
    (`A_gf2`, `b_gf2`).
    """

    def __init__(self, config: "BVVConfig"):
        """
        Build the code from `config`.

        A comes from `config.code_matrix` when given, otherwise it is sampled
        with `config.code_seed` and the configured weight bounds. b comes from
        `config.code_shift` when given; otherwise it is chosen so that
        `pad_token_id` maps to the all-zero code, or is all zeros when there
        is no pad token.

        Raises:
            ValueError: if A or b has the wrong shape, or A is singular over GF(2).
        """
        super().__init__()
        self.vocab_size = config.vocab_size
        self.code_bits = config.code_bits
        self.d_model = config.d_model
        self.scale = config.scale
        self.pad_token_id = config.pad_token_id
        self.zero_pad_code = config.zero_pad_code

        # Shift amounts 0..K-1 used to extract the bits of a token id.
        self.register_buffer(
            "bit_positions",
            torch.arange(self.code_bits, dtype=torch.long),
            persistent=False,
        )

        if config.code_matrix is None:
            A, _ = make_random_invertible_binary_matrix(
                code_bits=self.code_bits,
                seed=config.code_seed,
                min_row_weight=config.min_row_weight,
                min_col_weight=config.min_col_weight,
                device="cpu",
            )
        else:
            A = torch.tensor(config.code_matrix, dtype=torch.uint8)

        if A.shape != (self.code_bits, self.code_bits):
            raise ValueError(
                f"code_matrix must have shape {(self.code_bits, self.code_bits)}, got {tuple(A.shape)}"
            )
        if gf2_rank(A) != self.code_bits:
            raise ValueError("Provided/generated code_matrix is not invertible over GF(2).")

        if config.code_shift is not None:
            b = torch.tensor(config.code_shift, dtype=torch.uint8)
        elif config.pad_token_id is None:
            # No pad token to pin to the zero code: use a zero shift rather
            # than failing on torch.tensor(None).
            b = torch.zeros(self.code_bits, dtype=torch.uint8, device=A.device)
        else:
            # --- choose b so that pad_token_id maps to 0^K ---
            pad = torch.tensor(config.pad_token_id, dtype=torch.long, device=A.device)
            bit_positions = torch.arange(self.code_bits, dtype=torch.long, device=A.device)
            # LSB-first, same convention as ids_to_bits()
            pad_bits = ((pad >> bit_positions) & 1).to(torch.float32)  # [K]
            # forward computes bits @ A.T xor b, so b = A @ pad_bits cancels it
            b = torch.remainder(pad_bits @ A.to(torch.float32).T, 2.0).to(torch.uint8)

        if b.shape != (self.code_bits,):
            raise ValueError(
                f"code_shift must have shape {(self.code_bits,)}, got {tuple(b.shape)}"
            )

        self.register_buffer("A_gf2", (A & 1).contiguous(), persistent=True)
        self.register_buffer("b_gf2", (b & 1).contiguous(), persistent=True)

    def ids_to_bits(self, input_ids: torch.Tensor) -> torch.Tensor:
        """
        input_ids: [B, T] integer ids (cast to int64 if needed)
        returns:   [B, T, K] float32 in {0,1}, least-significant bit first

        Raises ValueError if any id lies outside [0, vocab_size).
        """
        if input_ids.dtype != torch.long:
            input_ids = input_ids.long()

        # min()/max() are undefined on an empty tensor; nothing to validate then.
        if input_ids.numel() > 0:
            lowest = input_ids.min().item()
            highest = input_ids.max().item()
            if lowest < 0 or highest >= self.vocab_size:
                raise ValueError(
                    f"input_ids out of range: min={lowest}, "
                    f"max={highest}, vocab_size={self.vocab_size}"
                )

        bits = ((input_ids.unsqueeze(-1) >> self.bit_positions) & 1).to(torch.float32)
        return bits

    def mix_bits_affine(self, bits: torch.Tensor) -> torch.Tensor:
        """
        bits: [B, T, K] float32 with entries 0/1
        returns c = bits @ A^T + b mod 2, float32 with entries 0/1
        """
        A = self.A_gf2.to(device=bits.device, dtype=torch.float32)
        b = self.b_gf2.to(device=bits.device, dtype=torch.float32)
        mixed = torch.remainder(torch.matmul(bits, A.T) + b, 2.0)
        return mixed

    def encode_bits(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Return the mixed binary codes [B, T, K] for the given token ids."""
        bits = self.ids_to_bits(input_ids)
        codes = self.mix_bits_affine(bits)
        return codes

    def forward(self, input_ids: torch.Tensor) -> torch.Tensor:
        """
        input_ids: [B, T]
        returns x: [B, T, d_model] float32, the code tiled `scale` times
        """
        codes = self.encode_bits(input_ids)  # [B, T, K]
        x = codes.repeat(1, 1, self.scale)  # [B, T, d_model]

        # Optional: keep pad positions exactly zero in the continuous input
        # tensor, even when an explicit code_shift maps pad to a non-zero code.
        if self.zero_pad_code and self.pad_token_id is not None:
            pad_mask = input_ids.eq(self.pad_token_id).unsqueeze(-1)
            x = x.masked_fill(pad_mask, 0.0)
        return x
def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0):
    """
    Build the rotary-embedding table of unit complex numbers.

    Returns a complex64 tensor of shape [end, dim // 2] whose entry (t, j)
    is exp(i * t * theta ** (-2j / dim)).
    """
    exponents = torch.arange(0, dim, 2).float() / dim
    inv_freq = 1.0 / (theta ** exponents)
    positions = torch.arange(end, device=inv_freq.device)
    angles = torch.outer(positions, inv_freq).float()
    magnitude = torch.ones_like(angles)
    return torch.polar(magnitude, angles)  # complex64
def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor):
    """
    View `freqs_cis` so it broadcasts against `x`.

    `freqs_cis` must have shape (x.shape[1], x.shape[-1]); the result keeps
    those two sizes at axes 1 and -1 and has size 1 on every other axis.
    """
    rank = x.ndim
    assert 0 <= 1 < rank
    axis1_size, last_size = x.shape[1], x.shape[-1]
    assert freqs_cis.shape == (axis1_size, last_size)

    target = [1] * rank
    target[1] = axis1_size
    target[-1] = last_size
    return freqs_cis.view(*target)
def apply_rotary_emb(
    xq: torch.Tensor,
    xk: torch.Tensor,
    freqs_cis: torch.Tensor,
):
    """
    Rotate queries and keys by the rotary table `freqs_cis`.

    Adjacent pairs along the last axis are treated as (real, imag) parts,
    multiplied by `freqs_cis` in float32, and flattened back from axis 3 on.
    Each output is cast back to the dtype of its input.
    """

    def _as_complex(t: torch.Tensor) -> torch.Tensor:
        pairs = t.float().reshape(*t.shape[:-1], -1, 2)
        return torch.view_as_complex(pairs)

    q_complex = _as_complex(xq)
    k_complex = _as_complex(xk)

    rotation = reshape_for_broadcast(freqs_cis, q_complex)

    q_rotated = torch.view_as_real(q_complex * rotation).flatten(3)
    k_rotated = torch.view_as_real(k_complex * rotation).flatten(3)
    return q_rotated.type_as(xq), k_rotated.type_as(xk)
class MultiHeadSelfAttention(nn.Module):
    """
    Multi-head self-attention with rotary position embeddings.

    Uses four bias-free d_model x d_model projections (q, k, v, o) and
    applies dropout to the attention probabilities.
    """

    def __init__(self, d_model, n_head, dropout=0.0):
        super().__init__()
        assert d_model % n_head == 0
        self.d_model = d_model
        self.n_head = n_head
        self.head_dim = d_model // n_head
        assert self.head_dim % 2 == 0, "head_dim must be even for rotary embeddings"

        self.q_proj = nn.Linear(d_model, d_model, bias=False)
        self.k_proj = nn.Linear(d_model, d_model, bias=False)
        self.v_proj = nn.Linear(d_model, d_model, bias=False)
        self.o_proj = nn.Linear(d_model, d_model, bias=False)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, freqs_cis, mask=None):
        """
        x:         [B, T, d_model]
        freqs_cis: rotary table for the T positions, shape [T, head_dim // 2]
        mask:      optional additive mask broadcastable to [B, n_head, T, T]
                   (0 where attention is allowed, a large negative value elsewhere)
        returns:   [B, T, d_model]
        """
        B, T, C = x.shape

        q = self.q_proj(x).view(B, T, self.n_head, self.head_dim)
        k = self.k_proj(x).view(B, T, self.n_head, self.head_dim)
        v = self.v_proj(x).view(B, T, self.n_head, self.head_dim)

        q, k = apply_rotary_emb(q, k, freqs_cis=freqs_cis)

        q = q.transpose(1, 2)  # (B, n_head, T, head_dim)
        k = k.transpose(1, 2)
        v = v.transpose(1, 2)

        attn_scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.head_dim)

        # Add the mask only after upcasting: in float16 a negative score plus
        # finfo.min overflows to -inf, and a row that is entirely -inf turns
        # the softmax into NaN. In float32 this is identical to adding first.
        attn_scores = attn_scores.float()
        if mask is not None:
            attn_scores = attn_scores + mask

        attn_probs = F.softmax(attn_scores, dim=-1).type_as(q)
        attn_probs = self.dropout(attn_probs)

        out = torch.matmul(attn_probs, v)
        out = out.transpose(1, 2).contiguous().view(B, T, C)
        return self.o_proj(out)
class TransformerMLP(nn.Module):
    """Position-wise feed-forward block: Linear -> GELU -> Linear -> Dropout, hidden width 4 * d_model."""

    def __init__(self, d_model, dropout=0.0):
        super().__init__()
        hidden = 4 * d_model
        # Order is fixed: the Sequential indices are part of the checkpoint
        # key names (net.0.*, net.2.*).
        stages = [
            nn.Linear(d_model, hidden),
            nn.GELU(),
            nn.Linear(hidden, d_model),
            nn.Dropout(dropout),
        ]
        self.net = nn.Sequential(*stages)

    def forward(self, x):
        """Apply the feed-forward stack to x (last axis of size d_model)."""
        return self.net(x)
class TransformerBlock(nn.Module):
    """
    One pre-norm transformer layer: LayerNorm -> self-attention -> residual,
    then LayerNorm -> MLP -> residual.
    """

    def __init__(self, d_model, n_head, dropout=0.0, layer_norm_eps=1e-5):
        super().__init__()

        def _make_norm():
            return nn.LayerNorm(d_model, eps=layer_norm_eps)

        # Registration order is kept: attention, MLP, then the two norms.
        self.self_attn = MultiHeadSelfAttention(d_model, n_head, dropout=dropout)
        self.mlp = TransformerMLP(d_model, dropout=dropout)
        self.input_layernorm = _make_norm()
        self.post_attention_layernorm = _make_norm()

    def forward(self, x, freqs_cis, mask=None):
        """Run attention and MLP sub-layers on x, each added back to its input."""
        attn_out = self.self_attn(self.input_layernorm(x), freqs_cis, mask)
        hidden = x + attn_out
        mlp_out = self.mlp(self.post_attention_layernorm(hidden))
        return hidden + mlp_out
class BVVForCausalLM(PreTrainedModel, GenerationMixin):
    """
    Causal language model whose input layer is the parameter-free
    BinaryAffineCodeInput instead of an embedding table, followed by
    `n_layer` pre-norm transformer blocks, a final LayerNorm and a linear
    LM head over the vocabulary.
    """

    config_class = BVVConfig
    main_input_name = "input_ids"

    def __init__(self, config: BVVConfig):
        super().__init__(config)

        # no nn.Embedding here
        self.input_code = BinaryAffineCodeInput(config)

        self.transformer_layers = nn.ModuleList([
            TransformerBlock(
                config.d_model,
                n_head=config.n_head,
                dropout=config.dropout,
                layer_norm_eps=config.layer_norm_eps,
            )
            for _ in range(config.n_layer)
        ])
        self.final_layernorm = nn.LayerNorm(config.d_model, eps=config.layer_norm_eps)
        self.lm_head = nn.Linear(config.d_model, config.vocab_size)

        # Complex rotary table for all block_size positions; not saved in checkpoints.
        self.register_buffer(
            "freqs_cis",
            precompute_freqs_cis(
                config.d_model // config.n_head,
                config.block_size,
            ),
            persistent=False,
        )

        self.post_init()

    def _init_weights(self, module):
        """Initialise Linear weights from N(0, initializer_range) and zero their biases."""
        if isinstance(module, nn.Linear):
            nn.init.normal_(module.weight, mean=0.0, std=self.config.initializer_range)
            if module.bias is not None:
                nn.init.zeros_(module.bias)

    def get_input_embeddings(self):
        # there is no embedding table
        return None

    def set_input_embeddings(self, value):
        raise NotImplementedError("This model uses algorithmic binary token codes, not nn.Embedding.")

    def get_output_embeddings(self):
        return self.lm_head

    def set_output_embeddings(self, new_embeddings):
        self.lm_head = new_embeddings

    def prepare_inputs_for_generation(self, input_ids, attention_mask=None, **kwargs):
        """Crop ids (and mask) to the last block_size positions; no KV cache is used."""
        if input_ids.shape[1] > self.config.block_size:
            input_ids = input_ids[:, -self.config.block_size:]
            if attention_mask is not None:
                attention_mask = attention_mask[:, -self.config.block_size:]
        return {
            "input_ids": input_ids,
            "attention_mask": attention_mask,
        }

    def _rotary_table(self, seq_len, device):
        """Return the complex rotary table for the first `seq_len` positions on `device`."""
        table = self.freqs_cis
        if torch.is_complex(table):
            return table[:seq_len].to(device)
        # The buffer is no longer complex (e.g. a dtype cast of the whole
        # module dropped the imaginary part), so it cannot be reinterpreted
        # as complex; rebuild the rows that are needed instead.
        head_dim = self.config.d_model // self.config.n_head
        return precompute_freqs_cis(head_dim, seq_len).to(device)

    def _build_attention_mask(self, attention_mask, batch, seq_len, like):
        """
        Build the additive attention mask (or None when nothing is masked).

        Causal and padding constraints are merged as booleans and filled once
        with finfo.min, so the mask never contains -inf (summing two
        finfo.min masks would overflow).
        """
        blocked = None
        if seq_len > 1:
            future = torch.ones((seq_len, seq_len), dtype=torch.bool, device=like.device)
            blocked = torch.triu(future, diagonal=1)[None, None, :, :]

        if attention_mask is not None:
            if attention_mask.shape != (batch, seq_len):
                raise ValueError(
                    f"attention_mask must have shape {(batch, seq_len)}, got {tuple(attention_mask.shape)}"
                )
            padded = attention_mask.to(like.device)[:, None, None, :].eq(0)
            blocked = padded if blocked is None else (blocked | padded)

        if blocked is None:
            return None
        mask = torch.zeros(blocked.shape, device=like.device, dtype=like.dtype)
        return mask.masked_fill(blocked, torch.finfo(like.dtype).min)

    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        labels=None,
        targets=None,
        return_dict=None,
        output_logits=True,
        **kwargs,
    ):
        """
        Run the model on `input_ids` of shape [B, T].

        Args:
            attention_mask: optional [B, T]; positions equal to 0 are not attended to.
            labels: optional [B, T]; shifted internally so position t predicts t+1.
            targets: optional [B, T] legacy targets, already aligned with the logits.
            return_dict: return a CausalLMOutput instead of a tuple
                (defaults to config.use_return_dict).
            output_logits: when False, logits are omitted from the result.

        In both loss paths, positions masked out by `attention_mask` and
        targets equal to `pad_token_id` are ignored.

        Raises:
            ValueError: on missing input_ids, both labels and targets given,
                T > block_size, or a wrongly shaped attention_mask.
        """
        if input_ids is None:
            raise ValueError("input_ids must be provided")
        if labels is not None and targets is not None:
            raise ValueError("Use either labels or targets, not both.")

        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        B, T = input_ids.shape
        if T > self.config.block_size:
            raise ValueError(f"Sequence length {T} exceeds block_size {self.config.block_size}")

        # ---- table-free input coding ----
        x = self.input_code(input_ids)
        # cast to model dtype if needed
        x = x.to(dtype=self.final_layernorm.weight.dtype)

        freqs_cis = self._rotary_table(T, x.device)
        mask = self._build_attention_mask(attention_mask, B, T, x)

        for layer in self.transformer_layers:
            x = layer(x, freqs_cis, mask)

        x = self.final_layernorm(x)
        logits = self.lm_head(x)

        loss = None
        if labels is not None:
            shift_logits = logits[:, :-1, :].contiguous()
            shift_labels = labels[:, 1:].to(logits.device).contiguous()
            if attention_mask is not None:
                ignored = attention_mask[:, 1:].to(logits.device).eq(0)
                shift_labels = shift_labels.masked_fill(ignored, -100)
            if self.config.pad_token_id is not None:
                shift_labels = shift_labels.masked_fill(shift_labels == self.config.pad_token_id, -100)
            loss = F.cross_entropy(
                shift_logits.float().view(-1, shift_logits.size(-1)),
                shift_labels.view(-1),
                ignore_index=-100,
            )
        elif targets is not None:
            legacy_targets = targets.to(logits.device).contiguous()
            if attention_mask is not None:
                ignored = attention_mask.to(logits.device).eq(0)
                legacy_targets = legacy_targets.masked_fill(ignored, -100)
            if self.config.pad_token_id is not None:
                legacy_targets = legacy_targets.masked_fill(legacy_targets == self.config.pad_token_id, -100)
            loss = F.cross_entropy(
                logits.float().view(-1, logits.size(-1)),
                legacy_targets.view(-1),
                ignore_index=-100,
            )

        if not return_dict:
            if output_logits:
                output = (logits,)
                return ((loss,) + output) if loss is not None else output
            return (loss,) if loss is not None else tuple()

        if output_logits:
            return CausalLMOutput(loss=loss, logits=logits)
        return CausalLMOutput(loss=loss, logits=None)

    def generate(self, input_ids, max_new_tokens, attention_mask=None, do_sample=False):
        """
        Append `max_new_tokens` tokens to `input_ids`, one at a time.

        Each step re-runs the model on the last block_size positions (no KV
        cache) and takes the argmax, or samples from the softmax when
        `do_sample` is True. Returns the extended id tensor. The module's
        train/eval mode is restored afterwards, also when a step raises.
        """
        was_training = self.training
        self.eval()

        if attention_mask is None:
            attention_mask = torch.ones_like(input_ids, dtype=torch.long)

        try:
            with torch.no_grad():
                for _ in range(max_new_tokens):
                    input_ids_cond = input_ids[:, -self.config.block_size:]
                    attention_mask_cond = attention_mask[:, -self.config.block_size:]

                    outputs = self(
                        input_ids=input_ids_cond,
                        attention_mask=attention_mask_cond,
                        return_dict=True,
                    )
                    logits = outputs.logits[:, -1, :]

                    if do_sample:
                        # Sample from float32 probabilities regardless of the model dtype.
                        probs = F.softmax(logits.float(), dim=-1)
                        next_token = torch.multinomial(probs, num_samples=1)
                    else:
                        next_token = torch.argmax(logits, dim=-1, keepdim=True)

                    input_ids = torch.cat([input_ids, next_token], dim=1)
                    attention_mask = torch.cat(
                        [attention_mask, torch.ones_like(next_token, dtype=attention_mask.dtype)],
                        dim=1,
                    )
        finally:
            if was_training:
                self.train()

        return input_ids