import math

import torch
import torch.nn as nn
from torch.nn import functional as F
from transformers import PreTrainedModel, PretrainedConfig
from transformers.generation import GenerationMixin
from transformers.modeling_outputs import CausalLMOutput


def gf2_rank(M: torch.Tensor) -> int:
    """
    Rank over GF(2). M: [n, m] tensor with 0/1 entries.
    """
    M = M.clone().to(torch.uint8) & 1
    n_rows, n_cols = M.shape
    rank = 0
    for col in range(n_cols):
        pivot = None
        for r in range(rank, n_rows):
            if M[r, col].item():
                pivot = r
                break
        if pivot is None:
            continue
        if pivot != rank:
            tmp = M[rank].clone()
            M[rank] = M[pivot]
            M[pivot] = tmp
        for r in range(n_rows):
            if r != rank and M[r, col].item():
                M[r] ^= M[rank]
        rank += 1
        if rank == n_rows:
            break
    return rank


def gf2_inverse(A: torch.Tensor) -> torch.Tensor:
    """
    Inverse over GF(2). A: [n, n] with 0/1 entries, invertible over GF(2).
    """
    A = A.clone().to(torch.uint8) & 1
    n = A.shape[0]
    I = torch.eye(n, dtype=torch.uint8, device=A.device)
    aug = torch.cat([A, I], dim=1)
    row = 0
    for col in range(n):
        pivot = None
        for r in range(row, n):
            if aug[r, col].item():
                pivot = r
                break
        if pivot is None:
            raise ValueError("Matrix is not invertible over GF(2).")
        if pivot != row:
            tmp = aug[row].clone()
            aug[row] = aug[pivot]
            aug[pivot] = tmp
        for r in range(n):
            if r != row and aug[r, col].item():
                aug[r] ^= aug[row]
        row += 1
    left = aug[:, :n]
    if not torch.equal(left, I):
        raise RuntimeError("GF(2) inverse construction failed.")
    return aug[:, n:]


def make_random_invertible_binary_matrix(
    code_bits: int,
    seed: int = 0,
    min_row_weight: int = 4,
    min_col_weight: int = 4,
    device: str = "cpu",
):
    """
    Random dense-ish invertible matrix A in GL(code_bits, 2) and random
    shift b in {0,1}^{code_bits}.

    min_row_weight / min_col_weight are optional constraints to avoid
    trivial near-permutation matrices.
    """
    g = torch.Generator(device=device)
    g.manual_seed(seed)
    while True:
        A = torch.randint(
            0, 2, (code_bits, code_bits), generator=g, dtype=torch.uint8, device=device
        )
        if gf2_rank(A) != code_bits:
            continue
        if min_row_weight is not None and not torch.all(A.sum(dim=1) >= min_row_weight):
            continue
        if min_col_weight is not None and not torch.all(A.sum(dim=0) >= min_col_weight):
            continue
        b = torch.randint(
            0, 2, (code_bits,), generator=g, dtype=torch.uint8, device=device
        )
        return A, b
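

# Sanity-check sketch for the GF(2) helpers above. `_gf2_selfcheck` is a
# hypothetical name introduced here for illustration; nothing else in this
# module calls it.
def _gf2_selfcheck(code_bits: int = 16, seed: int = 0) -> None:
    A, _ = make_random_invertible_binary_matrix(code_bits, seed=seed)
    assert gf2_rank(A) == code_bits  # full rank over GF(2) by construction
    A_inv = gf2_inverse(A)
    # A @ A_inv must reduce to the identity matrix modulo 2
    prod = torch.remainder(A.to(torch.int64) @ A_inv.to(torch.int64), 2).to(torch.uint8)
    assert torch.equal(prod, torch.eye(code_bits, dtype=torch.uint8))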


class BVVConfig(PretrainedConfig):
    model_type = "model_binary_affine_code_n_layer_32"

    def __init__(
        self,
        vocab_size=65536,
        code_bits=None,
        n_embed=16,  # backward-compatible alias for code_bits
        d_model=1024,
        n_head=32,
        n_layer=32,
        block_size=1024,
        dropout=0.0,
        layer_norm_eps=1e-5,
        initializer_range=0.02,
        pad_token_id=57344,
        pad_id=57344,
        bos_token_id=None,
        eos_token_id=None,
        tie_word_embeddings=False,
        use_cache=False,
        # affine code params
        code_seed=12345,
        code_matrix=None,  # optional explicit A
        code_shift=None,   # optional explicit b
        min_row_weight=4,
        min_col_weight=4,
        zero_pad_code=True,
        **kwargs,
    ):
        if pad_token_id is None:
            pad_token_id = 57344 if pad_id is None else pad_id
        super().__init__(
            pad_token_id=pad_token_id,
            bos_token_id=bos_token_id,
            eos_token_id=eos_token_id,
            tie_word_embeddings=tie_word_embeddings,
            use_cache=use_cache,
            **kwargs,
        )
        if code_bits is None:
            code_bits = n_embed
        if vocab_size != (1 << code_bits):
            raise ValueError(
                f"The exact minimal-code experiment requires "
                f"vocab_size == 2**code_bits, got vocab_size={vocab_size}, code_bits={code_bits}."
            )
        if d_model % code_bits != 0:
            raise ValueError(f"d_model ({d_model}) must be divisible by code_bits ({code_bits})")
        if d_model % n_head != 0:
            raise ValueError(f"d_model ({d_model}) must be divisible by n_head ({n_head})")
        if (d_model // n_head) % 2 != 0:
            raise ValueError("head_dim must be even for rotary embeddings")

        self.vocab_size = vocab_size
        self.block_size = block_size
        self.max_position_embeddings = block_size
        self.code_bits = code_bits
        self.n_embed = code_bits  # alias for old scripts
        self.d_model = d_model
        self.n_head = n_head
        self.n_layer = n_layer
        self.dropout = dropout
        self.layer_norm_eps = layer_norm_eps
        self.initializer_range = initializer_range
        self.scale = d_model // code_bits

        # code params
        self.code_seed = code_seed
        self.code_matrix = code_matrix
        self.code_shift = code_shift
        self.min_row_weight = min_row_weight
        self.min_col_weight = min_col_weight
        self.zero_pad_code = zero_pad_code

        # backward compatibility
        self.pad_id = pad_token_id


class BinaryAffineCodeInput(nn.Module):
    """
    Table-free token input:
        token id -> code_bits-bit code -> affine GF(2) mixing -> tiled lift to d_model
    No trainable parameters.
    """

    def __init__(self, config: BVVConfig):
        super().__init__()
        self.vocab_size = config.vocab_size
        self.code_bits = config.code_bits
        self.d_model = config.d_model
        self.scale = config.scale
        self.pad_token_id = config.pad_token_id
        self.zero_pad_code = config.zero_pad_code

        self.register_buffer(
            "bit_positions",
            torch.arange(self.code_bits, dtype=torch.long),
            persistent=False,
        )

        if config.code_matrix is None:
            A, _ = make_random_invertible_binary_matrix(
                code_bits=self.code_bits,
                seed=config.code_seed,
                min_row_weight=config.min_row_weight,
                min_col_weight=config.min_col_weight,
                device="cpu",
            )
        else:
            A = torch.tensor(config.code_matrix, dtype=torch.uint8)
            if A.shape != (self.code_bits, self.code_bits):
                raise ValueError(
                    f"code_matrix must have shape {(self.code_bits, self.code_bits)}, got {tuple(A.shape)}"
                )
        if gf2_rank(A) != self.code_bits:
            raise ValueError("Provided/generated code_matrix is not invertible over GF(2).")

        # --- choose b so that pad_token_id maps to 0^K ---
        if config.code_shift is None:
            pad = torch.tensor(config.pad_token_id, dtype=torch.long, device=A.device)
            bit_positions = torch.arange(self.code_bits, dtype=torch.long, device=A.device)
            # LSB-first, same convention as ids_to_bits()
            pad_bits = ((pad >> bit_positions) & 1).to(torch.float32)  # [K]
            # because forward uses: codes = bits @ A.T xor b
            b = torch.remainder(pad_bits @ A.to(torch.float32).T, 2.0).to(torch.uint8)
        else:
            b = torch.tensor(config.code_shift, dtype=torch.uint8)
            if b.shape != (self.code_bits,):
                raise ValueError(
                    f"code_shift must have shape {(self.code_bits,)}, got {tuple(b.shape)}"
                )

        self.register_buffer("A_gf2", (A & 1).contiguous(), persistent=True)
        self.register_buffer("b_gf2", (b & 1).contiguous(), persistent=True)

    def ids_to_bits(self, input_ids: torch.Tensor) -> torch.Tensor:
        """
        input_ids: [B, T] int64
        returns:   [B, T, K] float32 in {0, 1}
        """
        if input_ids.dtype != torch.long:
            input_ids = input_ids.long()
        if input_ids.min().item() < 0 or input_ids.max().item() >= self.vocab_size:
            raise ValueError(
                f"input_ids out of range: min={input_ids.min().item()}, "
                f"max={input_ids.max().item()}, vocab_size={self.vocab_size}"
            )
        bits = ((input_ids.unsqueeze(-1) >> self.bit_positions) & 1).to(torch.float32)
        return bits

    def mix_bits_affine(self, bits: torch.Tensor) -> torch.Tensor:
        """
        bits: [B, T, K] float32 with entries 0/1
        returns c = (bits @ A^T + b) mod 2
        """
        A = self.A_gf2.to(device=bits.device, dtype=torch.float32)
        b = self.b_gf2.to(device=bits.device, dtype=torch.float32)
        mixed = torch.remainder(torch.matmul(bits, A.T) + b, 2.0)
        return mixed

    def encode_bits(self, input_ids: torch.Tensor) -> torch.Tensor:
        bits = self.ids_to_bits(input_ids)
        codes = self.mix_bits_affine(bits)
        return codes

    def forward(self, input_ids: torch.Tensor) -> torch.Tensor:
        """
        returns x: [B, T, d_model]
        """
        codes = self.encode_bits(input_ids)  # [B, T, K]
        x = codes.repeat(1, 1, self.scale)   # [B, T, d_model]
        # Optional: keep pad positions exactly zero in the continuous input tensor
        if self.zero_pad_code and self.pad_token_id is not None:
            pad_mask = input_ids.eq(self.pad_token_id).unsqueeze(-1)
            x = x.masked_fill(pad_mask, 0.0)
        return x
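

# Illustrative sketch (hypothetical helper, not used elsewhere in this file):
# the affine code is information-preserving, so token ids can be recovered
# from the mixed bits via x = A^{-1} (c xor b) over GF(2), and pad maps to
# the all-zero code by construction of b.
def _affine_code_roundtrip_demo() -> None:
    config = BVVConfig(vocab_size=1 << 16, code_bits=16, d_model=64, n_head=4, n_layer=1)
    coder = BinaryAffineCodeInput(config)
    ids = torch.tensor([[0, 1, 12345, config.pad_token_id]])
    codes = coder.encode_bits(ids)  # [1, 4, 16] with entries in {0, 1}
    assert torch.all(codes[0, 3] == 0)  # pad position -> 0^K
    # invert the affine map, then repack the LSB-first bits into integer ids
    A_inv = gf2_inverse(coder.A_gf2).to(torch.float32)
    bits = torch.remainder((codes + coder.b_gf2.float()) @ A_inv.T, 2.0)
    recovered = (bits.long() << torch.arange(16)).sum(dim=-1)
    assert torch.equal(recovered, ids)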


def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0):
    freqs = 1.0 / (theta ** (torch.arange(0, dim, 2).float() / dim))
    t = torch.arange(end, device=freqs.device)
    freqs = torch.outer(t, freqs).float()
    freqs_cis = torch.polar(torch.ones_like(freqs), freqs)  # complex64
    return freqs_cis


def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor):
    ndim = x.ndim
    assert ndim > 1
    assert freqs_cis.shape == (x.shape[1], x.shape[-1])
    shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)]
    return freqs_cis.view(*shape)


def apply_rotary_emb(
    xq: torch.Tensor,
    xk: torch.Tensor,
    freqs_cis: torch.Tensor,
):
    xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2))
    xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))
    freqs_cis = reshape_for_broadcast(freqs_cis, xq_)
    xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3)
    xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3)
    return xq_out.type_as(xq), xk_out.type_as(xk)


class MultiHeadSelfAttention(nn.Module):
    def __init__(self, d_model, n_head, dropout=0.0):
        super().__init__()
        assert d_model % n_head == 0
        self.d_model = d_model
        self.n_head = n_head
        self.head_dim = d_model // n_head
        assert self.head_dim % 2 == 0, "head_dim must be even for rotary embeddings"

        self.q_proj = nn.Linear(d_model, d_model, bias=False)
        self.k_proj = nn.Linear(d_model, d_model, bias=False)
        self.v_proj = nn.Linear(d_model, d_model, bias=False)
        self.o_proj = nn.Linear(d_model, d_model, bias=False)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, freqs_cis, mask=None):
        B, T, C = x.shape
        q = self.q_proj(x).view(B, T, self.n_head, self.head_dim)
        k = self.k_proj(x).view(B, T, self.n_head, self.head_dim)
        v = self.v_proj(x).view(B, T, self.n_head, self.head_dim)

        q, k = apply_rotary_emb(q, k, freqs_cis=freqs_cis)

        q = q.transpose(1, 2)  # (B, n_head, T, head_dim)
        k = k.transpose(1, 2)
        v = v.transpose(1, 2)

        attn_scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.head_dim)
        if mask is not None:
            attn_scores = attn_scores + mask
        attn_probs = F.softmax(attn_scores.float(), dim=-1).type_as(q)
        attn_probs = self.dropout(attn_probs)

        out = torch.matmul(attn_probs, v)
        out = out.transpose(1, 2).contiguous().view(B, T, C)
        return self.o_proj(out)


class TransformerMLP(nn.Module):
    def __init__(self, d_model, dropout=0.0):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(d_model, 4 * d_model),
            nn.GELU(),
            nn.Linear(4 * d_model, d_model),
            nn.Dropout(dropout),
        )

    def forward(self, x):
        return self.net(x)
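

# Illustrative sketch of the rotary-embedding contract (hypothetical helper):
# apply_rotary_emb expects q/k shaped [B, T, n_head, head_dim] and freqs_cis
# shaped [T, head_dim // 2]; the per-position rotation is norm-preserving.
def _rotary_demo() -> None:
    B, T, n_head, head_dim = 2, 8, 4, 16
    q = torch.randn(B, T, n_head, head_dim)
    k = torch.randn(B, T, n_head, head_dim)
    freqs_cis = precompute_freqs_cis(head_dim, T)  # complex64, [T, head_dim // 2]
    q_rot, k_rot = apply_rotary_emb(q, k, freqs_cis=freqs_cis)
    assert q_rot.shape == q.shape and k_rot.shape == k.shape
    # multiplying each (even, odd) pair by a unit complex number keeps norms
    assert torch.allclose(q_rot.norm(dim=-1), q.norm(dim=-1), atol=1e-5)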


class TransformerBlock(nn.Module):
    def __init__(self, d_model, n_head, dropout=0.0, layer_norm_eps=1e-5):
        super().__init__()
        self.self_attn = MultiHeadSelfAttention(d_model, n_head, dropout=dropout)
        self.mlp = TransformerMLP(d_model, dropout=dropout)
        self.input_layernorm = nn.LayerNorm(d_model, eps=layer_norm_eps)
        self.post_attention_layernorm = nn.LayerNorm(d_model, eps=layer_norm_eps)

    def forward(self, x, freqs_cis, mask=None):
        x = x + self.self_attn(self.input_layernorm(x), freqs_cis, mask)
        x = x + self.mlp(self.post_attention_layernorm(x))
        return x
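

# Illustrative sketch (hypothetical helper): one pre-norm block applied with a
# causal mask built the same way as in BVVForCausalLM.forward below.
def _block_demo() -> None:
    d_model, n_head, T = 64, 4, 8
    block = TransformerBlock(d_model, n_head)
    x = torch.randn(1, T, d_model)
    freqs_cis = precompute_freqs_cis(d_model // n_head, T)
    mask = torch.triu(
        torch.full((1, 1, T, T), torch.finfo(x.dtype).min), diagonal=1
    )
    assert block(x, freqs_cis, mask).shape == x.shape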


class BVVForCausalLM(PreTrainedModel, GenerationMixin):
    config_class = BVVConfig
    main_input_name = "input_ids"

    def __init__(self, config: BVVConfig):
        super().__init__(config)

        # no nn.Embedding here
        self.input_code = BinaryAffineCodeInput(config)

        self.transformer_layers = nn.ModuleList([
            TransformerBlock(
                config.d_model,
                n_head=config.n_head,
                dropout=config.dropout,
                layer_norm_eps=config.layer_norm_eps,
            )
            for _ in range(config.n_layer)
        ])
        self.final_layernorm = nn.LayerNorm(config.d_model, eps=config.layer_norm_eps)
        self.lm_head = nn.Linear(config.d_model, config.vocab_size)

        self.register_buffer(
            "freqs_cis",
            precompute_freqs_cis(
                config.d_model // config.n_head,
                config.block_size,
            ),
            persistent=False,
        )

        self.post_init()

    def _init_weights(self, module):
        if isinstance(module, nn.Linear):
            nn.init.normal_(module.weight, mean=0.0, std=self.config.initializer_range)
            if module.bias is not None:
                nn.init.zeros_(module.bias)

    def get_input_embeddings(self):
        # there is no embedding table
        return None

    def set_input_embeddings(self, value):
        raise NotImplementedError("This model uses algorithmic binary token codes, not nn.Embedding.")

    def get_output_embeddings(self):
        return self.lm_head

    def set_output_embeddings(self, new_embeddings):
        self.lm_head = new_embeddings

    def prepare_inputs_for_generation(self, input_ids, attention_mask=None, **kwargs):
        if input_ids.shape[1] > self.config.block_size:
            input_ids = input_ids[:, -self.config.block_size:]
        if attention_mask is not None:
            attention_mask = attention_mask[:, -self.config.block_size:]
        return {
            "input_ids": input_ids,
            "attention_mask": attention_mask,
        }

    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        labels=None,
        targets=None,
        return_dict=None,
        output_logits=True,
        **kwargs,
    ):
        if input_ids is None:
            raise ValueError("input_ids must be provided")
        if labels is not None and targets is not None:
            raise ValueError("Use either labels or targets, not both.")
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        B, T = input_ids.shape
        if T > self.config.block_size:
            raise ValueError(f"Sequence length {T} exceeds block_size {self.config.block_size}")

        # ---- table-free input coding ----
        x = self.input_code(input_ids)
        # cast to model dtype if needed
        x = x.to(dtype=self.final_layernorm.weight.dtype)

        freqs_cis = self.freqs_cis[:T]
        if not torch.is_complex(freqs_cis):
            freqs_cis = torch.view_as_complex(freqs_cis.contiguous())
        freqs_cis = freqs_cis.to(x.device)

        mask = None
        mask_value = torch.finfo(x.dtype).min
        if T > 1:
            mask = torch.full((1, 1, T, T), mask_value, device=x.device, dtype=x.dtype)
            mask = torch.triu(mask, diagonal=1)
        if attention_mask is not None:
            if attention_mask.shape != (B, T):
                raise ValueError(f"attention_mask must have shape {(B, T)}, got {tuple(attention_mask.shape)}")
            pad_mask = torch.zeros((B, 1, 1, T), device=x.device, dtype=x.dtype)
            pad_mask = pad_mask.masked_fill(attention_mask[:, None, None, :].eq(0), mask_value)
            mask = pad_mask if mask is None else mask + pad_mask

        for layer in self.transformer_layers:
            x = layer(x, freqs_cis, mask)

        x = self.final_layernorm(x)
        logits = self.lm_head(x)

        loss = None
        if labels is not None:
            shift_logits = logits[:, :-1, :].contiguous()
            shift_labels = labels[:, 1:].contiguous()
            if attention_mask is not None:
                shift_labels = shift_labels.masked_fill(attention_mask[:, 1:].eq(0), -100)
            if self.config.pad_token_id is not None:
                shift_labels = shift_labels.masked_fill(shift_labels == self.config.pad_token_id, -100)
            loss = F.cross_entropy(
                shift_logits.float().view(-1, shift_logits.size(-1)),
                shift_labels.view(-1),
                ignore_index=-100,
            )
        elif targets is not None:
            legacy_targets = targets.contiguous()
            if attention_mask is not None:
                legacy_targets = legacy_targets.masked_fill(attention_mask.eq(0), -100)
            if self.config.pad_token_id is not None:
                legacy_targets = legacy_targets.masked_fill(legacy_targets == self.config.pad_token_id, -100)
            loss = F.cross_entropy(
                logits.float().view(-1, logits.size(-1)),
                legacy_targets.view(-1),
                ignore_index=-100,
            )

        if not return_dict:
            if output_logits:
                output = (logits,)
                return ((loss,) + output) if loss is not None else output
            return (loss,) if loss is not None else tuple()

        if output_logits:
            return CausalLMOutput(loss=loss, logits=logits)
        return CausalLMOutput(loss=loss, logits=None)

    def generate(self, input_ids, max_new_tokens, attention_mask=None, do_sample=False):
        was_training = self.training
        self.eval()
        if attention_mask is None:
            attention_mask = torch.ones_like(input_ids, dtype=torch.long)
        with torch.no_grad():
            for _ in range(max_new_tokens):
                input_ids_cond = input_ids[:, -self.config.block_size:]
                attention_mask_cond = attention_mask[:, -self.config.block_size:]
                outputs = self(
                    input_ids=input_ids_cond,
                    attention_mask=attention_mask_cond,
                    return_dict=True,
                )
                logits = outputs.logits[:, -1, :]
                if do_sample:
                    probs = F.softmax(logits, dim=-1)
                    next_token = torch.multinomial(probs, num_samples=1)
                else:
                    next_token = torch.argmax(logits, dim=-1, keepdim=True)
                input_ids = torch.cat([input_ids, next_token], dim=1)
                attention_mask = torch.cat(
                    [attention_mask, torch.ones_like(next_token, dtype=attention_mask.dtype)],
                    dim=1,
                )
        if was_training:
            self.train()
        return input_ids
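

# Hedged smoke test, not part of the model definition: a deliberately tiny
# configuration (the shipped defaults are d_model=1024, n_layer=32) chosen
# only so the module can be exercised quickly on CPU.
if __name__ == "__main__":
    config = BVVConfig(
        vocab_size=1 << 16,
        code_bits=16,
        d_model=64,
        n_head=4,
        n_layer=2,
        block_size=32,
    )
    model = BVVForCausalLM(config)
    ids = torch.randint(0, config.vocab_size, (2, 16))
    out = model(input_ids=ids, labels=ids, return_dict=True)
    print("loss:", out.loss.item(), "logits shape:", tuple(out.logits.shape))
    generated = model.generate(ids[:, :4], max_new_tokens=4)
    print("generated shape:", tuple(generated.shape))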