| |
| |
| |
|
|
| |
| from torch import nn |
| import torch |
| import numpy as np |
| import math |
|
|
| from torch.nn import TransformerEncoder, TransformerEncoderLayer |
|
|
|
|
def gen_timing_signal(length, channels, min_timescale=1.0, max_timescale=1.0e4):
    """
    Generates a [1, length, channels] timing signal consisting of sinusoids
    Adapted from:
    https://github.com/tensorflow/tensor2tensor/blob/master/tensor2tensor/layers/common_attention.py

    Args:
        length: number of positions.
        channels: embedding size; an odd value gets one zero-padded channel.
        min_timescale / max_timescale: geometric range of sinusoid wavelengths.

    Returns:
        torch.FloatTensor of shape (1, length, channels).
    """
    position = np.arange(length)
    num_timescales = channels // 2
    # max(..., 1.0) guards the channels < 4 case, where num_timescales - 1 == 0
    # previously caused a ZeroDivisionError
    log_timescale_increment = (math.log(float(max_timescale) / float(min_timescale))
                               / max(float(num_timescales) - 1, 1.0))
    inv_timescales = min_timescale * np.exp(np.arange(num_timescales).astype(float) * -log_timescale_increment)
    scaled_time = np.expand_dims(position, 1) * np.expand_dims(inv_timescales, 0)

    # first half sin, second half cos; pad one zero channel when channels is odd
    signal = np.concatenate([np.sin(scaled_time), np.cos(scaled_time)], axis=1)
    signal = np.pad(signal, [[0, 0], [0, channels % 2]],
                    'constant', constant_values=[0.0, 0.0])
    signal = signal.reshape([1, length, channels])

    return torch.from_numpy(signal).type(torch.FloatTensor)
|
|
class ACT_basic(nn.Module):
    """Adaptive Computation Time (Graves, 2016) halting unit.

    Repeatedly applies `fn` to `state`, predicting a per-position halting
    probability at each step; a position stops once its accumulated halting
    probability crosses `self.threshold`, or after `max_hop` steps.
    """

    def __init__(self, hidden_size):
        super(ACT_basic, self).__init__()
        self.sigma = nn.Sigmoid()
        self.p = nn.Linear(hidden_size, 1)
        # bias of 1 biases toward early halting at initialization
        self.p.bias.data.fill_(1)
        self.threshold = 1 - 0.1
        self.eps = 0.1  # std of the optional halting-probability noise

    def forward(self, *args, state, inputs, fn, time_enc, pos_enc, max_hop, encoder_output=None, **kwargs):
        """Run the ACT loop.

        Returns `(state, (remainders, n_updates))`, or
        `((state, *rest), (remainders, n_updates))` when `fn` returns a tuple.
        `time_enc` / `pos_enc` are accepted for interface compatibility but unused.
        """
        # consume the flag so it is not forwarded to `fn`
        noisy_halting = bool(kwargs.pop('noisy_halting', False))

        # follow the input's device instead of hard-coding CUDA
        device = inputs.device
        batch, seq = inputs.shape[0], inputs.shape[1]
        halting_probability = torch.zeros(batch, seq, device=device)
        remainders = torch.zeros(batch, seq, device=device)
        n_updates = torch.zeros(batch, seq, device=device)
        previous_state = torch.zeros_like(inputs)
        step = 0
        rest = None

        # loop while any position is below threshold and under the hop budget
        while ((halting_probability < self.threshold) & (n_updates < max_hop)).any():
            p = self.sigma(self.p(state)).squeeze(-1)
            if noisy_halting and self.training:
                p = p + torch.randn_like(p) * self.eps

            # positions that have not yet halted
            still_running = (halting_probability < 1.0).float()
            # positions that halt at THIS step
            new_halted = (halting_probability + p * still_running > self.threshold).float() * still_running
            # positions still running after this step
            still_running = (halting_probability + p * still_running <= self.threshold).float() * still_running

            halting_probability = halting_probability + p * still_running
            # remainder = leftover probability mass at the halting step
            remainders = remainders + new_halted * (1 - halting_probability)
            halting_probability = halting_probability + new_halted * remainders
            n_updates = n_updates + still_running + new_halted

            update_weights = p * still_running + new_halted * remainders

            # tensor truthiness (`if(encoder_output)`) would raise for
            # multi-element tensors; test identity instead
            if encoder_output is not None:
                state, _ = fn((state, encoder_output))
            else:
                state = fn(state, *args, **kwargs)
                if isinstance(state, tuple):
                    rest = state[1:]
                    state = state[0]

            # running weighted mean of states, ACT-style
            previous_state = ((state * update_weights.unsqueeze(-1)) +
                              (previous_state * (1 - update_weights.unsqueeze(-1))))
            step += 1

        if rest is None:
            return previous_state, (remainders, n_updates)
        else:
            return (previous_state, *rest), (remainders, n_updates)
|
|
|
|
class ACT_constant_depth():
    """Fixed-depth stand-in for ACT_basic: applies `fn` exactly `max_hop`
    times with no halting mechanism.

    `remainders` and `n_updates` are returned as zeros purely for interface
    compatibility with ACT_basic.
    """

    def __init__(self):
        super(ACT_constant_depth, self).__init__()

    def __call__(self, *args, state, inputs, fn, time_enc, pos_enc, max_hop, encoder_output=None, **kwargs):
        # follow the input's device instead of hard-coding CUDA
        device = inputs.device
        # NOTE(review): n_updates stays zero here, whereas the analogous
        # ACTForWholeARMT_constant_depth reports max_hop — kept as-is to
        # preserve existing loss/statistics behavior.
        remainders = torch.zeros(inputs.shape[0], inputs.shape[1], device=device)
        n_updates = torch.zeros(inputs.shape[0], inputs.shape[1], device=device)
        previous_state = torch.zeros_like(inputs)
        step = 0
        rest = None

        while step < max_hop:
            if encoder_output is not None:
                state, _ = fn((state, encoder_output))
            else:
                state = fn(state, *args, **kwargs)
                if isinstance(state, tuple):
                    rest = state[1:]
                    state = state[0]
            # no halting weights: the latest state simply wins
            previous_state = state
            step += 1

        if rest is None:
            return previous_state, (remainders, n_updates)
        else:
            return (previous_state, *rest), (remainders, n_updates)
|
|
class ACTForWholeARMT(nn.Module):
    """ACT halting unit applied to a whole ARMT forward pass.

    Identical to ACT_basic except that it receives two transition functions:
    `fn_no_update` (no memory write) is used while any position still ponders,
    and `fn_update` (memory write) is used on the final ACT step.
    """

    def __init__(self, hidden_size):
        super(ACTForWholeARMT, self).__init__()
        self.sigma = nn.Sigmoid()
        self.p = nn.Linear(hidden_size, 1)
        # bias of 1 biases toward early halting at initialization
        self.p.bias.data.fill_(1)
        self.threshold = 1 - 0.1

    def forward(self, *args, state, inputs, fn_no_update, fn_update, time_enc, pos_enc, max_hop, encoder_output=None, **kwargs):
        """Run ACT; same return contract as ACT_basic.forward."""
        # follow the input's device instead of hard-coding CUDA
        device = inputs.device
        batch, seq = inputs.shape[0], inputs.shape[1]
        halting_probability = torch.zeros(batch, seq, device=device)
        remainders = torch.zeros(batch, seq, device=device)
        n_updates = torch.zeros(batch, seq, device=device)
        previous_state = torch.zeros_like(inputs)
        step = 0
        rest = None

        while ((halting_probability < self.threshold) & (n_updates < max_hop)).any():
            p = self.sigma(self.p(state)).squeeze(-1)

            still_running = (halting_probability < 1.0).float()
            new_halted = (halting_probability + p * still_running > self.threshold).float() * still_running
            still_running = (halting_probability + p * still_running <= self.threshold).float() * still_running

            halting_probability = halting_probability + p * still_running
            remainders = remainders + new_halted * (1 - halting_probability)
            halting_probability = halting_probability + new_halted * remainders
            n_updates = n_updates + still_running + new_halted

            update_weights = p * still_running + new_halted * remainders

            # tensor truthiness would raise; test identity (matches ACT_transformer)
            if encoder_output is not None:
                if ((halting_probability < self.threshold) & (n_updates < max_hop)).any():
                    state, _ = fn_no_update((state, encoder_output))
                else:
                    # final ACT step: run the memory-updating forward
                    state, _ = fn_update((state, encoder_output))
            else:
                if ((halting_probability < self.threshold) & (n_updates < max_hop)).any():
                    state = fn_no_update(state, *args, **kwargs)
                else:
                    state = fn_update(state, *args, **kwargs)
                if isinstance(state, tuple):
                    rest = state[1:]
                    state = state[0]

            previous_state = ((state * update_weights.unsqueeze(-1)) +
                              (previous_state * (1 - update_weights.unsqueeze(-1))))
            step += 1

        if rest is None:
            return previous_state, (remainders, n_updates)
        else:
            return (previous_state, *rest), (remainders, n_updates)
|
|
class ACTForWholeARMT_constant_depth():
    """Fixed-depth stand-in for ACTForWholeARMT: always runs `max_hop` steps,
    calling `fn_no_update` for every step except the last and `fn_update`
    (the memory-writing forward) on the final step."""

    def __init__(self):
        super(ACTForWholeARMT_constant_depth, self).__init__()

    def __call__(self, *args, state, inputs, fn_no_update, fn_update, time_enc, pos_enc, max_hop, encoder_output=None, **kwargs):
        # follow the input's device instead of hard-coding CUDA
        device = inputs.device
        # no adaptive halting: zero remainders, every position does max_hop updates
        remainders = torch.zeros(inputs.shape[0], inputs.shape[1], device=device)
        n_updates = torch.full((inputs.shape[0], inputs.shape[1]), max_hop, device=device)
        previous_state = torch.zeros_like(inputs)
        step = 0
        rest = None

        while step < max_hop:
            # BUGFIX: the previous inner check was `step < max_hop`, which is
            # always true inside this loop, so fn_update (the memory write)
            # never ran; fire it on the last step instead.
            last_step = step == max_hop - 1
            if encoder_output is not None:
                if not last_step:
                    state, _ = fn_no_update((state, encoder_output))
                else:
                    state, _ = fn_update((state, encoder_output))
            else:
                if not last_step:
                    state = fn_no_update(state, *args, **kwargs)
                else:
                    state = fn_update(state, *args, **kwargs)
                if isinstance(state, tuple):
                    rest = state[1:]
                    state = state[0]
            # no halting weights: the latest state simply wins
            previous_state = state
            step += 1

        if rest is None:
            return previous_state, (remainders, n_updates)
        else:
            return (previous_state, *rest), (remainders, n_updates)
|
|
|
|
class ACT_transformer(nn.Module):
    """ACT halting unit whose per-position halting probabilities are predicted
    by a small causal TransformerEncoder over the current state, instead of
    the plain linear probe used by ACT_basic."""

    def __init__(self, hidden_size, num_heads=4, num_transformer_layers=1, dropout=0.1):
        super(ACT_transformer, self).__init__()
        transformer_layer = TransformerEncoderLayer(
            d_model=hidden_size,
            nhead=num_heads,
            dim_feedforward=hidden_size,
            dropout=dropout,
            norm_first=True
        )
        self.transformer = TransformerEncoder(transformer_layer,
                                              num_layers=num_transformer_layers)

        self.logit_ff = nn.Linear(hidden_size, 1)
        # bias of 1 biases toward early halting at initialization
        self.logit_ff.bias.data.fill_(1)

        self.sigma = nn.Sigmoid()
        self.threshold = 1 - 0.1

    def generate_causal_mask(self, seq_len):
        """Additive causal mask: 0 on/below the diagonal, -inf above it."""
        mask = torch.triu(torch.ones(seq_len, seq_len), diagonal=1)
        mask = mask.masked_fill(mask == 1, float('-inf'))
        return mask

    def forward(self, *args, state, inputs, fn, time_enc, pos_enc, max_hop, encoder_output=None, **kwargs):
        """Run ACT with transformer-predicted halting; same contract as ACT_basic."""
        batch_size, seq_len, hidden_size = inputs.shape
        # follow the input's device instead of hard-coding CUDA
        device = inputs.device
        halting_probability = torch.zeros(batch_size, seq_len, device=device)
        remainders = torch.zeros(batch_size, seq_len, device=device)
        n_updates = torch.zeros(batch_size, seq_len, device=device)
        previous_state = torch.zeros_like(inputs)
        step = 0
        rest = None

        causal_mask = self.generate_causal_mask(seq_len).to(device)

        while ((halting_probability < self.threshold) & (n_updates < max_hop)).any():
            # TransformerEncoder (batch_first=False) expects (seq, batch, dim)
            state_transformed = self.transformer(
                state.permute(1, 0, 2),
                mask=causal_mask
            )
            state_transformed = state_transformed.permute(1, 0, 2)

            p = self.sigma(self.logit_ff(state_transformed)).squeeze(-1)

            still_running = (halting_probability < 1.0).float()
            new_halted = (halting_probability + p * still_running > self.threshold).float() * still_running
            still_running = (halting_probability + p * still_running <= self.threshold).float() * still_running
            halting_probability = halting_probability + p * still_running
            remainders = remainders + new_halted * (1 - halting_probability)
            halting_probability = halting_probability + new_halted * remainders
            n_updates = n_updates + still_running + new_halted
            update_weights = p * still_running + new_halted * remainders

            if encoder_output is not None:
                state, _ = fn((state, encoder_output))
            else:
                state = fn(state, *args, **kwargs)
                if isinstance(state, tuple):
                    rest = state[1:]
                    state = state[0]

            previous_state = (
                (state * update_weights.unsqueeze(-1)) +
                (previous_state * (1 - update_weights.unsqueeze(-1)))
            )
            step += 1

        if rest is None:
            return previous_state, (remainders, n_updates)
        else:
            return (previous_state, *rest), (remainders, n_updates)
|
|
|
|
| |
| import math |
| import torch |
| from torch.nn import CrossEntropyLoss |
| from transformers.modeling_outputs import CausalLMOutputWithCrossAttentions |
| from transformers.cache_utils import Cache, DynamicCache |
| from torch.nn.functional import relu as r |
| import torch.nn.functional as F |
| import os |
| from dataclasses import dataclass |
| from transformers.modeling_outputs import ModelOutput |
|
|
@dataclass
class ARMTOutput(ModelOutput):
    """
    Custom output format for ARMT with all necessary fields.
    This replaces Munch in the original implementation.
    """
    # Language-model logits over the vocabulary.
    logits: torch.FloatTensor = None
    # Total loss (may include auxiliary ACT terms in addition to ce_loss).
    loss: torch.FloatTensor = None
    # Hidden states returned by the base model, if requested.
    hidden_states: torch.FloatTensor = None
    # Attention maps, if requested.
    attentions: tuple = None
    # KV-cache / past-state passthrough from the base model.
    past_key_values: tuple = None
    # ACT ponder bookkeeping: accumulated halting remainders.
    remainders: torch.FloatTensor = None
    # ACT ponder bookkeeping: accumulated number of updates.
    n_updates: torch.FloatTensor = None
    # Plain cross-entropy component of the loss.
    ce_loss: torch.FloatTensor = None
|
|
| |
| try: |
| from cut_cross_entropy import linear_cross_entropy |
| CUT_CROSS_ENTROPY_AVAILABLE = True |
| except ImportError: |
| CUT_CROSS_ENTROPY_AVAILABLE = False |
| print("Warning: cut_cross_entropy not available, falling back to standard CrossEntropyLoss") |
|
|
| |
| try: |
| from baselines.rwkv.language_modeling import RWKVModel |
| RWKV_imported = True |
| except ImportError: |
| print("*** Can't import RWKV model ***") |
| RWKV_imported = False |
def dpfp(x, nu=1):
    """DPFP feature map (deterministic parameter-free projection): expands the
    last dimension of `x` from d to 2 * nu * d non-negative features."""
    pos_neg = torch.cat([r(x), r(-x)], dim=-1)
    shifted = [pos_neg.roll(shifts=j, dims=-1) for j in range(1, nu + 1)]
    return torch.cat([pos_neg] * nu, dim=-1) * torch.cat(shifted, dim=-1)
|
|
class DPFP:
    """Callable DPFP feature map with a fixed expansion factor `nu`.

    Maps the last dimension from d to 2 * nu * d non-negative features.
    """

    def __init__(self, nu):
        self.nu = nu

    def __call__(self, x):
        rectified = torch.cat([r(x), r(-x)], dim=-1)
        rolled = torch.cat(
            [rectified.roll(shifts=j, dims=-1) for j in range(1, self.nu + 1)],
            dim=-1,
        )
        tiled = torch.cat([rectified] * self.nu, dim=-1)
        return tiled * rolled
def attn_mask_to_4d(attn_mask, upper, query_len):
    """Expand a 2-D (batch, key) padding mask into a 4-D
    (batch, 1, query_len, key) mask combined with a triangular causal pattern.

    `upper=True` uses the upper triangle, otherwise the lower triangle.
    Returns None when `attn_mask` is None.
    """
    if attn_mask is None:
        return None
    key_len = attn_mask.size(-1)
    ones = torch.ones(query_len, key_len, dtype=attn_mask.dtype, device=attn_mask.device)
    tri = torch.triu(ones) if upper else torch.tril(ones)
    # outer product of the padding mask with the triangular pattern
    # (equivalent to einsum 'bj,ij->bij')
    combined = attn_mask[:, None, :] * tri[None, :, :]
    return combined.unsqueeze(1)
|
|
def invert_attn_mask(attn_mask, dtype):
    """Turn a 1/0 "keep" mask into an additive mask: kept positions become 0,
    masked positions become the most negative value representable in `dtype`."""
    neg_fill = torch.finfo(dtype).min
    keep = torch.ones_like(attn_mask)
    return (keep - attn_mask) * neg_fill
|
|
|
|
|
|
class AssociativeLayerWrapper(torch.nn.Module):
    """Wraps a single transformer layer with a linear-attention-style
    associative memory (fast weights).

    Read path: `associate` queries the memory matrix `W_mem` with
    phi(W_mq @ h) and is added residually to the hidden states.
    Write path: `update_mem` writes the last `num_mem_tokens` positions of the
    layer output into `W_mem` via a delta-rule update.
    """

    def __init__(self, layer, d_model, num_mem_tokens, d_mem, n_heads=1, correction=True, info=None, use_denom=True, gating=False) -> None:
        super().__init__()
        self.info = info
        self.seg_num = 0
        self.d_model = d_model
        self.num_mem_tokens = num_mem_tokens
        self.d_mem = d_mem
        self.n_heads = n_heads
        self.gating = gating
        nu = 3  # DPFP expansion factor; the key dim becomes 2 * nu * d_mem
        self.d_key = 2 * nu * d_mem

        assert self.d_mem % n_heads == 0 and self.d_model % n_heads == 0

        self.phi = DPFP(nu)  # non-negative feature map for keys/queries

        # whether reads are normalized by the accumulated key mass `z`
        self.use_denom = use_denom

        # match the wrapped layer's parameter dtype
        layer_dtype = next(layer.parameters()).dtype
        self.W_mq = torch.nn.Linear(d_model, d_mem, bias=False, dtype=layer_dtype)
        self.W_mk = torch.nn.Linear(d_model, d_mem, bias=False, dtype=layer_dtype)
        self.W_mv = torch.nn.Linear(d_model, d_model, bias=False, dtype=layer_dtype)
        if gating:
            self.W_mb = torch.nn.Linear(d_model, d_model, dtype=layer_dtype)
        else:
            self.W_mb = torch.nn.Linear(d_model, n_heads, dtype=layer_dtype)
        # zero-init values so a fresh memory contributes nothing at first
        torch.nn.init.zeros_(self.W_mv.weight)

        self.layer = layer
        self.generate_mode = False   # True => never write to the memory
        self.first_seg = True        # True until the first memory write
        self.correction = correction
        self.zero_mem()

    def _to_heads(self, x):
        # (b, s, d) -> (b, h, s, d/h)
        bsz, seq_len, d_model = x.shape
        x = x.reshape(bsz, seq_len, self.n_heads, d_model // self.n_heads)
        x = x.permute(0, 2, 1, 3)
        return x

    def _from_heads(self, x):
        # (b, h, s, d/h) -> (b, s, d)
        bsz, n_heads, seq_len, d_head = x.shape
        x = x.permute(0, 2, 1, 3).reshape(bsz, seq_len, n_heads * d_head)
        return x

    def associate(self, hidden_states):
        """Read from the associative memory with `hidden_states` as queries."""
        bsz, seq_len, d_model = hidden_states.shape

        self.W_mem = self.W_mem.to(hidden_states.device)
        if self.use_denom:
            self.z = self.z.to(hidden_states.device)

        q = self._to_heads(self.W_mq(hidden_states))
        mq = self.phi(q)
        mq = F.normalize(mq, dim=-1, p=2.0)

        num = torch.einsum('ihjk,ihkt->ihjt', mq, self.W_mem)
        if self.use_denom:
            # normalize by accumulated key mass; +1e-5 avoids division by zero
            denom = torch.einsum("ihk,ihjk->ihj", self.z, mq)[..., None] + 1e-5
            hidden_states = num / denom
        else:
            hidden_states = num
        hidden_states = self._from_heads(hidden_states)
        return hidden_states

    def forward(self, hidden_states, *args, **kwargs):
        """Memory read (residual) -> wrapped layer -> memory write."""
        if not self.first_seg:
            hidden_states = self.associate(hidden_states) + hidden_states
        out = self.layer(hidden_states, *args, **kwargs)
        if not self.generate_mode:
            # the trailing num_mem_tokens positions are the write tokens
            if isinstance(out, tuple):
                mem_tokens = out[0][:, -self.num_mem_tokens:]
            else:
                mem_tokens = out[:, -self.num_mem_tokens:]
            self.update_mem(mem_tokens)
        return out

    def forward_no_update(self, hidden_states, *args, **kwargs):
        """Same as `forward` but never writes to the memory.

        (This method was previously defined twice, identically; the duplicate
        definition has been removed.)
        """
        if not self.first_seg:
            hidden_states = self.associate(hidden_states) + hidden_states
        out = self.layer(hidden_states, *args, **kwargs)
        return out

    def update_mem(self, mem_tokens):
        """Delta-rule write of `mem_tokens` into the associative memory."""
        self.W_mem = self.W_mem.to(mem_tokens.device)
        if self.use_denom:
            self.z = self.z.to(mem_tokens.device)
        k = self._to_heads(self.W_mk(mem_tokens))
        mk = self.phi(k)
        mk = F.normalize(mk, dim=-1, p=2.0)

        new_mv = self._to_heads(self.W_mv(mem_tokens))
        if not self.first_seg:
            # retrieve what the memory currently stores for these keys
            num = torch.einsum('ihjk,ihkt->ihjt', mk, self.W_mem)
            if self.use_denom:
                denom = torch.einsum("ihj,ihkj->ihk", self.z, mk)[..., None] + 1e-5
                prev_mv = num / denom
                if self.correction:
                    # down-weight writes for keys the memory already covers
                    new_info_coef = (1 - denom / (torch.linalg.norm(mk, dim=-1) ** 2)[..., None])
                    new_info_coef = torch.clip(new_info_coef, 0, 1).detach()
                else:
                    new_info_coef = 1
            else:
                prev_mv = num
        else:
            prev_mv = torch.zeros_like(new_mv, device=new_mv.device)
            new_info_coef = 1

        # delta rule: store only the difference from what is already stored
        mv = new_mv - prev_mv

        mb = self._to_heads(torch.sigmoid(self.W_mb(mem_tokens)))

        # per-channel gate ('t') when gating, else per-head scalar gate
        # ('x' broadcasts over the value dim)
        einop = f"ihjk,ihjt,ihj{'t' if self.gating else 'x'}->ihkt"
        associations = torch.einsum(einop, mk, mv, mb)

        self.W_mem = self.W_mem + associations

        if self.use_denom:
            self.z = self.z + (new_info_coef*mk).sum(dim=-2)

        self.seg_num += 1
        self.first_seg = False

    def freeze_mem(self):
        """Freeze the memory projection parameters."""
        self.W_mb.weight.requires_grad = False
        self.W_mb.bias.requires_grad = False
        self.W_mq.weight.requires_grad = False
        self.W_mk.weight.requires_grad = False
        self.W_mv.weight.requires_grad = False

    def zero_mem(self):
        """Reset the associative memory to empty."""
        self.first_seg = True
        layer_dtype = next(self.layer.parameters()).dtype
        # W_mem / z are plain buffers (not Parameters); gradients flow through
        # the additive updates in update_mem, not through these initial zeros
        self.W_mem = torch.zeros(1, self.n_heads, self.d_key // self.n_heads, self.d_model // self.n_heads, dtype=layer_dtype)
        self.W_mem.requires_grad_(False)
        if self.use_denom:
            self.z = torch.zeros(1, self.n_heads, self.d_key // self.n_heads, dtype=layer_dtype)
            self.z.requires_grad_(False)
        self.seg_num = 0

    def detach_mem(self):
        """Detach the memory from the autograd graph (e.g. for truncated BPTT)."""
        self.W_mem = self.W_mem.detach()
        if self.use_denom:
            self.z = self.z.detach()
|
|
|
|
|
|
|
|
class AdaptiveAssociativeLayerWrapper(AssociativeLayerWrapper):
    """Associative layer whose memory READ (`associate`) is wrapped in an ACT
    loop, so the number of read hops per position is adaptive.

    NOTE(review): a second class with this exact name is defined later in this
    module and shadows this definition at import time.
    """
    def __init__(self,
                 layer,
                 d_model,
                 num_mem_tokens,
                 d_mem,
                 max_hop,
                 n_heads=1,
                 correction=True,
                 info=None,
                 use_denom=True,
                 gating=False,
                 constant_depth=False,
                 ) -> None:
        super().__init__(layer, d_model, num_mem_tokens, d_mem, n_heads, correction, info, use_denom, gating)
        # halting controller: learned (ACT_basic) or fixed-depth loop
        self.act = ACT_basic(d_model) if not constant_depth else ACT_constant_depth()
        self.depth = max_hop        # maximum number of ACT hops
        self.max_length = 1024      # length used for the timing signal

        self.timing_signal = gen_timing_signal(self.max_length, d_model)
        self.position_signal = gen_timing_signal(self.depth, d_model)

        # running ponder statistics, accumulated across segments
        self.remainders = torch.zeros(1,)
        self.n_updates = torch.zeros(1,)
        self.segments_passed = torch.zeros(1,)

    def associate(self, hidden_states):
        """Adaptive read: applies the parent `associate` up to `self.depth`
        times under ACT control and accumulates ponder statistics."""
        self.remainders = self.remainders.to(hidden_states.device)
        self.n_updates = self.n_updates.to(hidden_states.device)
        self.segments_passed = self.segments_passed.to(hidden_states.device)
        out, (remainders, n_updates) = self.act(
            state=hidden_states,
            inputs=hidden_states,
            fn=super().associate,
            time_enc=self.timing_signal,
            pos_enc=self.position_signal,
            max_hop=self.depth
        )
        # accumulate scalar means so the running stats keep shape (1,)
        self.remainders = self.remainders + remainders.mean()
        self.n_updates = self.n_updates + n_updates.mean()
        self.segments_passed = self.segments_passed + 1
        return out

    def zero_mem(self):
        # reset ponder statistics together with the associative memory
        self.remainders = torch.zeros(1,)
        self.n_updates = torch.zeros(1,)
        self.segments_passed = torch.zeros(1,)
        return super().zero_mem()

    def detach_mem(self):
        # reset ponder statistics when the memory is detached
        self.remainders = torch.zeros(1,)
        self.n_updates = torch.zeros(1,)
        self.segments_passed = torch.zeros(1,)
        return super().detach_mem()
|
|
|
|
|
|
class AdaptiveAssociativeLayerWrapper2(AssociativeLayerWrapper):
    """Associative layer whose WHOLE forward pass (read + wrapped layer) is
    run under ACT; the memory write happens once, after the ACT loop, using
    the final output's memory tokens."""

    def __init__(self,
                 layer,
                 d_model,
                 num_mem_tokens,
                 d_mem,
                 max_hop,
                 n_heads=1,
                 correction=True,
                 info=None,
                 use_denom=True,
                 gating=False,
                 act_format='linear',
                 noisy_halting=False,
                 constant_depth=False,
                 ) -> None:
        super().__init__(layer, d_model, num_mem_tokens, d_mem, n_heads, correction, info, use_denom, gating)

        # NOTE: 'transformer' takes precedence over constant_depth (kept as-is)
        if act_format == 'transformer':
            self.act = ACT_transformer(d_model)
        elif constant_depth:
            self.act = ACT_constant_depth()
        elif act_format == 'linear':
            self.act = ACT_basic(d_model)
        else:
            # was `raise NotImplemetedError` — a NameError at runtime
            raise NotImplementedError(f'Unknown act_format: {act_format}')

        self.depth = max_hop        # maximum number of ACT hops
        self.max_length = 1024      # length used for the timing signal

        self.noisy_halting = noisy_halting

        self.timing_signal = gen_timing_signal(self.max_length, d_model)
        self.position_signal = gen_timing_signal(self.depth, d_model)

        # running ponder statistics, accumulated across segments
        self.remainders = torch.zeros(1,)
        self.n_updates = torch.zeros(1,)
        self.segments_passed = torch.zeros(1,)

    def forward(self, hidden_states, *args, **kwargs):
        """Run the wrapped layer under ACT, then write the memory once."""
        self.remainders = self.remainders.to(hidden_states.device)
        self.n_updates = self.n_updates.to(hidden_states.device)
        self.segments_passed = self.segments_passed.to(hidden_states.device)

        if self.noisy_halting:
            kwargs['noisy_halting'] = self.noisy_halting
        fwd = super().forward_no_update
        out, (remainders, n_updates) = self.act(
            *args,
            state=hidden_states,
            inputs=hidden_states,
            fn=fwd,
            time_enc=self.timing_signal,
            pos_enc=self.position_signal,
            max_hop=self.depth,
            **kwargs
        )
        if not self.generate_mode:
            # assumes the ACT output is a tuple whose first element is the
            # hidden states — TODO(review): confirm for plain-tensor layers
            mem_tokens = out[0][:, -self.num_mem_tokens:]
            self.update_mem(mem_tokens)
            self.first_seg = False
        self.remainders = self.remainders + remainders.mean()
        self.n_updates = self.n_updates + n_updates.mean()
        self.segments_passed = self.segments_passed + 1
        return out

    def zero_mem(self):
        # reset ponder statistics together with the associative memory
        self.remainders = torch.zeros(1,)
        self.n_updates = torch.zeros(1,)
        self.segments_passed = torch.zeros(1,)
        return super().zero_mem()

    def detach_mem(self):
        # reset ponder statistics when the memory is detached
        self.remainders = torch.zeros(1,)
        self.n_updates = torch.zeros(1,)
        self.segments_passed = torch.zeros(1,)
        return super().detach_mem()
|
|
|
|
class AdaptiveAssociativeLayerWrapper(AssociativeLayerWrapper):
    """Adaptive-depth associative READ (ACT around `associate`).

    NOTE(review): this re-definition shadows the identically named class
    defined earlier in this module.
    """

    def __init__(self,
                 layer,
                 d_model,
                 num_mem_tokens,
                 d_mem,
                 max_hop,
                 n_heads=1,
                 correction=True,
                 info=None,
                 use_denom=True,
                 gating=False,
                 ) -> None:
        super().__init__(layer, d_model, num_mem_tokens, d_mem, n_heads, correction, info, use_denom, gating)
        self.act = ACT_basic(d_model)
        self.depth = max_hop        # maximum number of ACT hops
        self.max_length = 1024      # length used for the timing signal

        self.timing_signal = gen_timing_signal(self.max_length, d_model)
        self.position_signal = gen_timing_signal(self.depth, d_model)

        # running ponder statistics, accumulated across segments, shape (1,)
        self.remainders = torch.zeros(1,)
        self.n_updates = torch.zeros(1,)
        self.segments_passed = torch.zeros(1,)

    def associate(self, hidden_states):
        """Adaptive read: applies the parent `associate` up to `self.depth`
        times under ACT control and accumulates ponder statistics."""
        self.remainders = self.remainders.to(hidden_states.device)
        self.n_updates = self.n_updates.to(hidden_states.device)
        self.segments_passed = self.segments_passed.to(hidden_states.device)
        out, (remainders, n_updates) = self.act(
            state=hidden_states,
            inputs=hidden_states,
            fn=super().associate,
            time_enc=self.timing_signal,
            pos_enc=self.position_signal,
            max_hop=self.depth
        )
        # BUGFIX: accumulate scalar means so the running stats keep shape (1,);
        # previously the full (batch, seq) tensors were added, silently
        # broadcasting self.remainders / self.n_updates to per-token shape
        # (and diverging from the other adaptive wrappers, which use .mean()).
        self.remainders = self.remainders + remainders.mean()
        self.n_updates = self.n_updates + n_updates.mean()
        self.segments_passed = self.segments_passed + 1
        return out

    def zero_mem(self):
        # reset ponder statistics together with the associative memory
        self.remainders = torch.zeros(1,)
        self.n_updates = torch.zeros(1,)
        self.segments_passed = torch.zeros(1,)
        return super().zero_mem()
| |
|
|
|
|
| class AssociativeMemoryCell(torch.nn.Module): |
| def __init__(self, |
| base_model, |
| num_mem_tokens, |
| d_mem, |
| layers_attr: str = 'model.layers', |
| wrap_pos=False, |
| correction=True, |
| n_heads=1, |
| use_denom=True, |
| gating=False, |
| freeze_mem=False, |
| act_on=False, |
| max_hop=4, |
| act_type='layer', |
| act_format='linear', |
| noisy_halting=False, |
| constant_depth=False, |
| attend_to_previous_input=False, |
| use_sink=False, |
| **rmt_config |
| ): |
| super().__init__() |
| self.model = base_model |
| |
| self.attend_to_previous_input = attend_to_previous_input |
| self.previous_input = None |
| self.use_sink = use_sink |
| |
| self.RWKV_ARMT = isinstance(self.model, RWKVModel) if RWKV_imported else False |
|
|
| self.num_mem_tokens = num_mem_tokens |
| self.d_mem = d_mem |
| self.d_model = base_model.get_input_embeddings().embedding_dim |
| self.W_mem = [] |
|
|
| self.constant_depth = constant_depth |
|
|
| self.layers_attrs = layers_attr.split('.') |
|
|
| def _get_layers_from_model(model_root): |
| layers_obj = model_root |
| for attr in self.layers_attrs: |
| layers_obj = getattr(layers_obj, attr) |
| return layers_obj |
|
|
| layers = _get_layers_from_model(self.model) |
| |
| for i in range(len(layers)): |
| kw = dict( |
| layer=layers[i], |
| d_model=self.d_model, |
| num_mem_tokens=self.num_mem_tokens, |
| d_mem=self.d_mem, |
| correction=correction, |
| info={'layer': i}, |
| n_heads=n_heads, |
| use_denom=use_denom, |
| gating=gating, |
| ) |
| if act_on and act_type != 'model': |
| kw['act_format'] = act_format |
| if act_on and act_type == 'model' and act_format != 'linear': |
| raise NotImplementedError |
| if act_on and (act_type != 'model'): |
| kw['max_hop'] = max_hop |
| kw['constant_depth'] = self.constant_depth |
| kw['act_format'] = act_format |
| if act_on and noisy_halting: |
| kw['noisy_halting'] = noisy_halting |
| if not act_on: |
| layers[i] = AssociativeLayerWrapper(**kw) |
| elif act_type == 'associative': |
| layers[i] = AdaptiveAssociativeLayerWrapper(**kw) |
| elif act_type == 'layer': |
| layers[i] = AdaptiveAssociativeLayerWrapper2(**kw) |
| elif act_type == 'model': |
| layers[i] = AssociativeLayerWrapper(**kw) |
| else: |
| raise f'Unknown ACT type: {act_type}' |
|
|
| if act_type == 'model': |
| self.act = ACTForWholeARMT(self.d_model) if not self.constant_depth else ACTForWholeARMT_constant_depth() |
| self.depth = max_hop |
| self.max_length = 1024 |
| self.timing_signal = gen_timing_signal(self.max_length, self.d_model) |
| self.position_signal = gen_timing_signal(self.depth, self.d_model) |
| self.act_type = act_type |
|
|
| self.create_memory(num_mem_tokens) |
| self.wrap_pos = wrap_pos |
| self.act_on = act_on |
| if wrap_pos: |
| self.wrap_positional_embeddings(num_mem_tokens) |
| |
| if freeze_mem: |
| for layer in _get_layers_from_model(self.model): |
| layer.freeze_mem() |
|
|
| |
| self.get_layers = lambda: _get_layers_from_model(self.model) |
| |
| def generate_mode(self, is_on): |
| for layer in self.get_layers(): |
| layer.generate_mode = is_on |
| |
| def create_memory(self, num_mem_tokens): |
| self.num_mem_tokens = num_mem_tokens |
| embeddings = self.model.get_input_embeddings() |
| memory_dim = getattr(self.model.config, 'n_embd', self.model.config.hidden_size) |
| memory_weights = torch.randn((num_mem_tokens, memory_dim), device=embeddings.weight.data.device, dtype=embeddings.weight.data.dtype) * embeddings.weight.data.std() |
|
|
| self.register_parameter('memory', torch.nn.Parameter(memory_weights, requires_grad=True)) |
| if self.use_sink: |
| self.sink = torch.nn.Parameter(torch.randn((1, memory_dim), device=embeddings.weight.data.device, dtype=embeddings.weight.data.dtype), requires_grad=True) |
|
|
|
|
    def wrap_positional_embeddings(self, num_mem_tokens):
        """Grow the positional-embedding table by `num_mem_tokens` extra
        positions reserved for the memory tokens, and rebuild each layer's
        causal bias to the new size.

        NOTE(review): assumes a GPT-2-like base model exposing
        `transformer.wpe` and `transformer.h[*]` — confirm for other backbones.
        """
        num_pos_embs, emb_dim = self.model.transformer.wpe.weight.shape
        prev_embs = self.model.transformer.wpe.weight.detach()
        self.model.transformer.wpe = torch.nn.Embedding(num_mem_tokens + num_pos_embs, emb_dim)

        new_num_pos = num_pos_embs + num_mem_tokens
        with torch.no_grad():
            # copy old embeddings into the leading rows; the trailing
            # num_mem_tokens rows keep their fresh initialization
            self.model.transformer.wpe.weight[:len(self.model.transformer.wpe.weight)-num_mem_tokens] = prev_embs
        for layer in self.model.transformer.h:
            # `layer` is an AssociativeLayerWrapper, hence layer.layer.attn
            layer.layer.attn.bias = torch.tril(torch.ones((new_num_pos, new_num_pos), dtype=torch.uint8)).view(
                1, 1, new_num_pos, new_num_pos
            )
|
|
| def set_memory(self, input_shape): |
| memory = self.memory.repeat(input_shape[0], 1, 1) |
| if self.use_sink: |
| sink = self.sink.repeat(input_shape[0], 1, 1) |
| else: |
| sink = None |
| return memory, sink |
|
|
| def zero_mem(self): |
| for layer in self.get_layers(): |
| layer.zero_mem() |
| self.previous_input = None |
| |
| def detach_mem(self): |
| for layer in self.get_layers(): |
| layer.detach_mem() |
| pass |
|
|
| def forward(self, input_ids, labels=None, labels_mask=None, zero_mem=False, attention_mask=None, **kwargs): |
| if self.act_type != 'model': |
| out = self.forward_with_update(input_ids, labels, labels_mask, zero_mem, attention_mask=attention_mask, **kwargs) |
| else: |
| seg_kwargs = self.process_input(input_ids=input_ids, |
| labels=labels, |
| labels_mask=labels_mask, |
| zero_mem=zero_mem, |
| attention_mask=attention_mask, |
| **kwargs |
| ) |
| out = self.gptneox_forward_act(**seg_kwargs) |
| out = self.process_output(out, labels=labels, labels_mask=labels_mask) |
| return out |
|
|
    def forward_with_update(self, input_ids, labels=None, labels_mask=None, zero_mem=False, **kwargs):
        """Run one segment through the wrapped model, letting each layer
        update its associative memory.

        Optionally prepends the previous segment's tokens
        (attend_to_previous_input) and strips their logits afterwards.
        """
        current_input_ids = input_ids.clone()
        if self.attend_to_previous_input and self.previous_input is not None:
            input_ids = torch.cat([self.previous_input, input_ids], dim=1)

        if zero_mem:
            self.zero_mem()

        seg_kwargs = self.process_input(input_ids, **kwargs)

        layers = self.get_layers()
        if self.RWKV_ARMT and not layers[0].generate_mode:
            # RWKV path: split the segment into (ordinary tokens, memory
            # tokens) and run them as two consecutive calls so the memory-token
            # pass consumes the recurrent state produced by the token pass.
            input1 = dict()
            input2 = dict()
            for item in seg_kwargs:
                if isinstance(seg_kwargs[item], torch.Tensor):
                    input1[item] = seg_kwargs[item][:, :-self.num_mem_tokens]
                    input2[item] = seg_kwargs[item][:, -self.num_mem_tokens:]
                else:
                    input1[item] = seg_kwargs[item]
                    input2[item] = seg_kwargs[item]

            self.generate_mode(True)
            out = self.model(**input1)
            self.generate_mode(False)
            # snapshot the recurrent state before the mem-token pass mutates it
            state_tmp = tuple([torch.clone(state) for state in out['state']])
            out = ARMTOutput(**{k: torch.clone(t) if isinstance(t, torch.Tensor) else t for k, t in out.items()})
            input2['state'] = out['state']
            # second pass only to trigger the layers' memory writes
            _ = self.model(**input2)
            out['state'] = state_tmp
        else:
            out = self.model(**seg_kwargs)

        if self.attend_to_previous_input and self.previous_input is not None:
            # drop logits that belong to the prepended previous segment
            out['logits'] = out['logits'][:, self.previous_input.size(1):]
        out = self.process_output(out, labels, labels_mask, **kwargs)
        self.previous_input = current_input_ids
        return out
|
|
    def process_input(self, input_ids, **kwargs):
        """Embed one segment and append the write-memory tokens (and optional
        sink token), building the kwargs dict for the wrapped model.

        Layout along the sequence axis: [sink?] + tokens + [write memory].
        """
        memory_state, sink = self.set_memory(input_ids.shape)
        seg_kwargs = dict(**kwargs)
        inputs_embeds = kwargs.get('inputs_embeds')
        if inputs_embeds is None:
            inputs_embeds = self.model.get_input_embeddings()(input_ids)
        if self.use_sink:
            inputs_embeds = torch.cat([sink, inputs_embeds, memory_state], dim=1)
        else:
            inputs_embeds = torch.cat([inputs_embeds, memory_state], dim=1)

        # the model is driven purely by embeddings from here on
        seg_kwargs['input_ids'] = None
        seg_kwargs['inputs_embeds'] = inputs_embeds
        if kwargs.get('attention_mask') is not None:
            seg_kwargs['attention_mask'] = self.pad_attention_mask(kwargs['attention_mask'], dtype=inputs_embeds.dtype)
        if kwargs.get('prev_attn_mask') is not None:
            # previous-segment mask is concatenated along the key axis
            prev_seg_attn_mask = self.pad_prev_seg_attn_mask(kwargs['prev_attn_mask'], dtype=inputs_embeds.dtype)
            seg_kwargs['attention_mask'] = torch.cat([prev_seg_attn_mask, seg_kwargs['attention_mask']], dim=-1)
        if 'prev_attn_mask' in seg_kwargs:
            seg_kwargs.pop('prev_attn_mask')
        seg_kwargs['output_hidden_states'] = True

        if self.wrap_pos:
            # memory tokens use the dedicated trailing positions of the
            # (extended) positional-embedding table; see wrap_positional_embeddings
            num_pos_embs = self.model.transformer.wpe.weight.shape[0]
            ordinary_pos = torch.arange(0, input_ids.size(1), dtype=torch.long, device=input_ids.device)
            write_pos = torch.arange(num_pos_embs - self.num_mem_tokens, num_pos_embs, dtype=torch.long, device=input_ids.device)
            seg_kwargs['position_ids'] = torch.cat([
                ordinary_pos,
                write_pos
            ]).long().unsqueeze(0)
        return seg_kwargs
|
|
| |
|
|
| def pad_attention_mask(self, attention_mask, dtype=float): |
| if self.num_mem_tokens in {0, None}: |
| return attention_mask |
| else: |
| shape = list(attention_mask.shape) |
| if len(shape) == 4: |
|
|
| shape[-1] += self.num_mem_tokens + self.use_sink |
| shape[-2] += self.num_mem_tokens + self.use_sink |
| mask = torch.ones(*shape, dtype=dtype).to(attention_mask.device) |
| mask[..., int(self.use_sink):-self.num_mem_tokens, int(self.use_sink):-self.num_mem_tokens] = attention_mask |
| if self.use_sink: |
| mask[..., 0, 1:] = 0 |
| mask[..., :-self.num_mem_tokens, -self.num_mem_tokens:] = 0 |
| |
| if not os.environ.get("NOT_INVERT_ATTN_MASK"): |
| mask = invert_attn_mask(mask, dtype) |
| else: |
| shape[-1] += self.num_mem_tokens + self.use_sink |
| mask = torch.ones(*shape, dtype=dtype).to(attention_mask.device) |
| mask[..., int(self.use_sink):-self.num_mem_tokens] = attention_mask |
| return mask.to(dtype) |
|
|
| def pad_prev_seg_attn_mask(self, prev_seg_attn_mask, dtype=float): |
| if self.num_mem_tokens in {0, None}: |
| return prev_seg_attn_mask |
| else: |
| shape = list(prev_seg_attn_mask.shape) |
| if len(shape) == 4: |
| shape[-2] += self.num_mem_tokens + self.use_sink |
| mask = torch.ones(*shape, dtype=dtype).to(prev_seg_attn_mask.device) |
| mask[..., int(self.use_sink):-self.num_mem_tokens, :] = prev_seg_attn_mask |
| if self.use_sink: |
| mask[..., 0, :] = 0 |
| if not os.environ.get("NOT_INVERT_ATTN_MASK"): |
| mask = invert_attn_mask(mask, dtype) |
| else: |
| mask = prev_seg_attn_mask |
| return mask.to(dtype) |
| |
    def process_output(self, model_outputs, labels, labels_mask, **kwargs):
        """Strip sink/memory positions from the model output and, when `labels`
        are given, attach a cross-entropy loss (summed, then divided by the
        number of valid label tokens) as `out['ce_loss']`.

        NOTE(review): `flat_mask` is only bound when `labels_mask` is not None;
        every later use is guarded by the same condition, so this is safe.
        """
        # Slice away the sink token (front) and memory tokens (back) from the
        # logits/hidden states; the RWKV path keeps outputs as-is.
        if (self.num_mem_tokens not in {0, None}) and not self.RWKV_ARMT:
            out = CausalLMOutputWithCrossAttentions()
            out['logits'] = model_outputs.logits[:, int(self.use_sink):-self.num_mem_tokens]
            if kwargs.get('output_hidden_states'):
                out['hidden_states'] = [lh[:, int(self.use_sink):-self.num_mem_tokens] for lh in model_outputs.hidden_states]
            if kwargs.get('output_attentions'):
                out['attentions'] = model_outputs['attentions']
        else:
            out = model_outputs

        if labels is not None:
            # Standard causal shift: predict token t+1 from position t.
            labels = labels[..., 1:].contiguous()
            flat_labels = labels.view(-1)

            if labels_mask is not None:
                # Mask is shifted the complementary way so it aligns with the
                # shifted labels; assumes a boolean-indexable mask.
                flat_mask = labels_mask[..., :-1].contiguous().view(-1)
                flat_labels = flat_labels[flat_mask]

            # Fast path: fused linear+CE (cut-cross-entropy) computed straight
            # from the last hidden states and the LM head weights.
            if CUT_CROSS_ENTROPY_AVAILABLE and hasattr(self.model, 'embed_out'):
                if 'hidden_states' in model_outputs and model_outputs.hidden_states is not None:
                    hidden_states = model_outputs.hidden_states[-1]
                    # Same sink/memory slicing as for the logits above.
                    if self.num_mem_tokens not in {0, None}:
                        hidden_states = hidden_states[:, int(self.use_sink):-self.num_mem_tokens]
                    hidden_states = hidden_states[..., :-1, :].contiguous()
                    flat_hidden_states = hidden_states.view(-1, hidden_states.size(-1))
                    if labels_mask is not None:
                        flat_hidden_states = flat_hidden_states[flat_mask]
                    lm_head_weights = self.model.embed_out.weight
                    ce_loss = linear_cross_entropy(
                        flat_hidden_states,
                        lm_head_weights,
                        flat_labels,
                        reduction='sum'
                    )
                else:
                    # Hidden states unavailable: fall back to logits-based CE.
                    logits = out['logits'][..., :-1, :].contiguous()
                    flat_logits = logits.view(-1, logits.size(-1))
                    if labels_mask is not None:
                        flat_logits = flat_logits[flat_mask]
                    ce_loss_fn = CrossEntropyLoss(reduction='sum')
                    ce_loss = ce_loss_fn(flat_logits, flat_labels)
            else:
                # Plain logits-based cross entropy.
                logits = out['logits'][..., :-1, :].contiguous()
                flat_logits = logits.view(-1, logits.size(-1))
                if labels_mask is not None:
                    flat_logits = flat_logits[flat_mask]
                ce_loss_fn = CrossEntropyLoss(reduction='sum')
                ce_loss = ce_loss_fn(flat_logits, flat_labels)

            # Normalize the summed loss by the number of valid target tokens
            # (mask sum, or non-ignored labels), clamped to avoid div-by-zero.
            if labels_mask is not None:
                denom = labels_mask[..., :-1].contiguous().view(-1).sum()
            else:
                denom = (flat_labels != -100).sum()
            denom = torch.clamp(denom, min=1)
            out['ce_loss'] = ce_loss / denom

        if kwargs.get('use_cache', False):
            out['past_key_values'] = model_outputs.past_key_values
        # Model-level ACT bookkeeping is forwarded when enabled.
        if self.act_on and self.act_type == 'model':
            out['remainders'] = model_outputs['remainders']
            out['n_updates'] = model_outputs['n_updates']
        return out
| |
| def generate(self, input_ids, attention_mask, zero_mem=False, **generate_kwargs): |
| if zero_mem: |
| self.zero_mem() |
| |
| |
| self.generate_mode(True) |
| seg_kwargs = self.process_input(input_ids, attention_mask=attention_mask) |
| out = self.model.generate( |
| inputs_embeds=seg_kwargs['inputs_embeds'][:, :-self.num_mem_tokens], |
| attention_mask=seg_kwargs['attention_mask'][:, :-self.num_mem_tokens], |
| **generate_kwargs |
| ) |
| self.generate_mode(False) |
| return out |
| |
| def update_past_key_values_sw(self, past_key_values, window_size): |
| past_key_values = past_key_values.to_legacy_cache() |
| past_key_values = [ |
| [ |
| k_or_v[..., -(window_size+self.use_sink):, :] |
| for k_or_v in seg_kv |
| ] |
| for seg_kv in past_key_values |
| ] |
| past_key_values = DynamicCache.from_legacy_cache(past_key_values) |
| return past_key_values |
| |
    def greedy_generate_sw(self, input_ids, attention_mask, prev_attn_mask, **generate_kwargs):
        """Greedy (argmax) decoding with a sliding-window KV cache.

        First runs the full final segment (with the previous segment visible
        through `prev_attn_mask` and `past_key_values`), then decodes one token
        at a time, trimming the cache back to `window_size` after each step.

        Required keys in `generate_kwargs`: `window_size`, `max_new_tokens`,
        `past_key_values`, `eos_token_id`.
        Returns the generated token ids, shape [batch, <=max_new_tokens].
        """
        self.generate_mode(True)
        window_size = generate_kwargs['window_size']
        max_new_tokens = generate_kwargs['max_new_tokens']
        past_key_values = self.update_past_key_values_sw(generate_kwargs['past_key_values'], window_size)
        eos_token_id = generate_kwargs['eos_token_id']
        # Keep 2D copies: the running decode loop needs flat [batch, key] masks.
        prev_attn_mask_2d = prev_attn_mask.clone()
        attention_mask_2d = attention_mask.clone()

        # 4D masks for the prefill pass over the whole segment.
        attention_mask = attn_mask_to_4d(attention_mask, upper=False, query_len=attention_mask.size(-1))
        prev_attn_mask = attn_mask_to_4d(prev_attn_mask, upper=True, query_len=attention_mask.size(-1))
        seg_kwargs = self.process_input(input_ids=input_ids, attention_mask=attention_mask, prev_attn_mask=prev_attn_mask, past_key_values=past_key_values)
        # Drop the memory-token suffix: decoding operates on real tokens only.
        seg_kwargs['inputs_embeds'] = seg_kwargs['inputs_embeds'][..., :-self.num_mem_tokens, :]
        seg_kwargs['attention_mask'] = seg_kwargs['attention_mask'][..., :-self.num_mem_tokens, :-self.num_mem_tokens]
        outputs = self.model(**seg_kwargs, use_cache=True)

        next_token_logits = outputs.logits[:, -1, :]

        past_key_values = outputs.past_key_values
        past_key_values = self.update_past_key_values_sw(past_key_values, window_size)

        generated_ids = None
        # 2D mask covering: previous segment + the single "current" position +
        # the just-processed segment.
        sw_attention_mask = torch.cat([prev_attn_mask_2d, torch.ones(attention_mask_2d.size(0), 1).to(prev_attn_mask_2d.device), attention_mask_2d], dim=-1)

        for i in range(max_new_tokens):
            # Greedy pick of the next token.
            next_token_id = torch.argmax(next_token_logits, dim=-1).unsqueeze(-1)
            if generated_ids is not None:
                generated_ids = torch.cat([generated_ids, next_token_id], dim=-1)
            else:
                generated_ids = next_token_id
            next_input = next_token_id
            # Extend the mask by one and keep only the last window(+1+sink) keys.
            sw_attention_mask = torch.cat([sw_attention_mask, torch.ones_like(next_token_id).to(sw_attention_mask.device)], dim=-1)[..., -window_size-1-self.use_sink:]
            with torch.no_grad():
                outputs = self.model(
                    input_ids=next_input,
                    attention_mask=sw_attention_mask,
                    past_key_values=past_key_values,
                    use_cache=True,
                    # Absolute position of the token being decoded.
                    cache_position=torch.full((1,), window_size + i + input_ids.size(-1) + self.use_sink).to(input_ids.device)
                )
            past_key_values = self.update_past_key_values_sw(outputs.past_key_values, window_size)
            next_token_logits = outputs.logits[:, -1, :]
            # Stop early once every sequence in the batch emitted EOS.
            if (next_token_id[:, 0] == eos_token_id).all():
                break
        self.generate_mode(False)
        return generated_ids
| |
|
|
| def apply_layers(self, hidden_states, causal_mask, position_ids, cache_position, position_embeddings, update_mem=True): |
| if not update_mem: |
| tmp = [] |
| for i in range(len(self.layers)): |
| tmp.append(self.layers[i].forward) |
| self.layers[i].forward = self.layers[i].forward_no_update |
|
|
| for layer in self.get_layers(): |
| hidden_states = layer( |
| hidden_states, |
| attention_mask=causal_mask, |
| position_ids=position_ids, |
| cache_position=cache_position, |
| position_embeddings=position_embeddings, |
| )[0] |
|
|
| if not update_mem: |
| for i, layer in enumerate(self.get_layers()): |
| layer.forward = tmp[i] |
| return hidden_states |
| |
| |
    def gptneox_forward_act(self, inputs_embeds, labels=None, labels_mask=None, zero_mem=False, attention_mask=None, **kwargs):
        """Adaptive-computation-time (ACT) forward pass through a GPT-NeoX
        backbone: the layer stack is applied repeatedly by `self.act` until
        halting, then normalized and projected to logits.

        NOTE(review): `labels`, `labels_mask` and `zero_mem` are accepted but
        unused in this body — loss handling happens in `process_output`.
        """
        # Reuse the backbone's embedding dropout.
        drop = self.model.gpt_neox.emb_dropout
        hidden_states = drop(inputs_embeds)
        seq_length = hidden_states.shape[1]
        cache_position = torch.arange(0, seq_length, device=hidden_states.device)
        position_ids = cache_position.unsqueeze(0)

        position_embeddings = self.model.gpt_neox.rotary_emb(hidden_states, position_ids)
        causal_mask = self.model.gpt_neox._update_causal_mask(
            attention_mask, hidden_states, cache_position, None, False
        )

        # ACT controller: applies the layer stack up to `self.depth` hops.
        # The no-update variant reads, but does not write, associative memory.
        out, (remainders, n_updates) = self.act(
            state=hidden_states,
            inputs=hidden_states,
            fn_no_update=lambda *args, **kwargs: self.apply_layers(*args, **kwargs, update_mem=False),
            fn_update=self.apply_layers,
            time_enc=self.timing_signal,
            pos_enc=self.position_signal,
            max_hop=self.depth,
            causal_mask=causal_mask,
            position_ids=position_ids,
            cache_position=cache_position,
            position_embeddings=position_embeddings
        )
        hidden_states = self.model.gpt_neox.final_layer_norm(out)

        lm_logits = self.model.embed_out(hidden_states)
        return ARMTOutput(logits=lm_logits, n_updates=n_updates, remainders=remainders)
|
|
class AssociativeRecurrentWrapper(torch.nn.Module):
    """Outer recurrence over segments: splits long inputs into segments, feeds
    them one by one through an associative `memory_cell`, threads memory state
    (and, optionally, a sliding-window KV cache) between segments, and
    aggregates per-segment outputs into one loss/logits object.
    """
    def __init__(self, memory_cell, **rmt_kwargs):
        super().__init__()

        # memory_cell: the per-segment model (AssociativeMemoryCell-like).
        self.memory_cell = memory_cell
        # rmt_config: raw kwargs (segment_size, sliding_window, act_on, ...).
        self.rmt_config = rmt_kwargs
        # Memory state carried across forward() calls when requested.
        self.last_state = None


    def gradient_checkpointing_enable(self, *args, **kwargs):
        """Forward gradient checkpointing to the wrapped base model."""
        self.memory_cell.model.gradient_checkpointing_enable(*args, **kwargs)


    def process_segment(self, segment_kwargs, next_seg_len=None):
        """Run one segment through the memory cell and prepare the kwargs
        (cache, masks, state) for the NEXT segment.

        Returns:
            (cell_out, next_segment_kwargs)
        """
        sliding_window = self.rmt_config['sliding_window'] if 'sliding_window' in self.rmt_config else False
        attend_to_previous_input = self.rmt_config['attend_to_previous_input'] if 'attend_to_previous_input' in self.rmt_config else False
        attn_mask = segment_kwargs['attention_mask']
        seg_len = segment_kwargs['input_ids'].size(-1)

        segment_kwargs['use_cache'] = sliding_window
        if segment_kwargs.get('past_key_values') is None:
            segment_kwargs['past_key_values'] = None
        if segment_kwargs.get('prev_attn_mask') is None:
            segment_kwargs['prev_attn_mask'] = None
        segment_kwargs['zero_mem'] = False
        if sliding_window or attend_to_previous_input:
            # Lower-triangular 4D mask for the current segment.
            segment_kwargs['attention_mask'] = attn_mask_to_4d(attn_mask, upper=False, query_len=seg_len)
        # A None state must not be forwarded as an explicit kwarg.
        if 'state' in segment_kwargs and segment_kwargs['state'] is None:
            segment_kwargs.pop('state')
        num_mem_tokens = self.memory_cell.num_mem_tokens
        cell_out = self.memory_cell(**segment_kwargs)
        state = cell_out.get('state')
        if (sliding_window or attend_to_previous_input) and next_seg_len is not None:
            # Upper-triangular mask letting the NEXT segment see this one.
            prev_attn_mask = attn_mask_to_4d(attn_mask, upper=True, query_len=next_seg_len)
        else:
            prev_attn_mask = None
        if sliding_window:
            # Keep only this segment's real-token KVs (drop memory-token KVs at
            # the tail), detached so the cache carries no graph across segments.
            past_key_values = [
                [
                    k_or_v[..., -(num_mem_tokens+seg_len):k_or_v.size(-2)-num_mem_tokens, :].detach()
                    for k_or_v in seg_kv
                ]
                for seg_kv in cell_out['past_key_values']
            ]
            # Repack with the same cache class when the output used one.
            if not isinstance(cell_out['past_key_values'], tuple) and not isinstance(cell_out['past_key_values'], list):
                past_key_values = cell_out['past_key_values'].from_legacy_cache(past_key_values)
            else:
                past_key_values = DynamicCache.from_legacy_cache(past_key_values)
        else:
            past_key_values = None
        next_segment_kwargs = dict()
        next_segment_kwargs['use_cache'] = sliding_window
        next_segment_kwargs['past_key_values'] = past_key_values
        next_segment_kwargs['prev_attn_mask'] = prev_attn_mask
        next_segment_kwargs['zero_mem'] = False
        if state is not None:
            next_segment_kwargs['state'] = state
        return cell_out, next_segment_kwargs


    def forward(self,
                input_ids,
                labels=None,
                labels_mask=None,
                inputs_embeds=None,
                attention_mask=None,
                output_attentions=None,
                output_hidden_states=None,
                input_segmented=False,
                output_only_last_segment=False,
                use_previous_batch_state=torch.zeros(1),
                num_items_in_batch=None,
                **kwargs
                ):
        """Segment the input (unless `input_segmented`), run segments through
        the memory cell sequentially, and aggregate outputs.

        NOTE(review): the explicit `num_items_in_batch` parameter is never
        read — the body uses `kwargs.get('num_items_in_batch')`, which cannot
        contain it (a caller passing it binds the named parameter). Confirm
        which one callers rely on.
        NOTE(review): `use_previous_batch_state=torch.zeros(1)` is a tensor
        default argument, created once at import time.
        """
        if input_segmented:
            # Inputs arrive pre-segmented along dim 1: [batch, n_seg, seg_len].
            n_segs = input_ids.shape[1] if not (input_ids is None) else inputs_embeds.shape[1]
            segmented = [dict(
                input_ids=input_ids[:, i] if not (input_ids is None) else None,
                inputs_embeds=inputs_embeds[:, i] if not (inputs_embeds is None) else None,
                attention_mask=attention_mask[:, i],
                labels=labels[:, i] if not (labels is None) else None,
                labels_mask=labels_mask[:, i] if not (labels_mask is None) else None,
            ) for i in range(n_segs)]
            # Flatten labels back to [batch, n_seg * seg_len] for loss computation.
            labels = torch.cat([labels[:, i] for i in range(n_segs)], dim=1)
            if labels_mask is not None:
                labels_mask = torch.cat([labels_mask[:, i] for i in range(n_segs)], dim=1)
        else:
            segmented = self.segment(input_ids=input_ids, inputs_embeds=inputs_embeds, attention_mask=attention_mask, labels=labels, labels_mask=labels_mask)

        cell_outputs = []
        # Either start from a clean memory or continue (detached) from the
        # previous batch's state.
        if not use_previous_batch_state.all() or self.last_state is None:
            self.memory_cell.zero_mem()
            state = None
        else:
            self.memory_cell.detach_mem()
            state = self.last_state
        next_seg_kwargs = dict(state=state)
        for seg_num, segment in enumerate(segmented):
            # The next segment's length is needed to size prev_attn_mask.
            if seg_num != len(segmented) - 1:
                next_seg_len = segmented[seg_num + 1]['input_ids'].size(-1)
            else:
                next_seg_len = None

            segment_with_kwargs = dict(**segment, **next_seg_kwargs)
            if kwargs.get('num_items_in_batch') is not None:
                segment_with_kwargs['num_items_in_batch'] = kwargs['num_items_in_batch']
            cell_out, next_seg_kwargs = self.process_segment(segment_with_kwargs, next_seg_len=next_seg_len)
            if (not output_only_last_segment) or (seg_num == len(segmented) - 1):
                cell_outputs.append(cell_out)

        out = self.process_outputs(cell_outputs, labels=labels,
                                   labels_mask=labels_mask,
                                   output_attentions=output_attentions,
                                   output_hidden_states=output_hidden_states,
                                   num_items_in_batch=kwargs.get('num_items_in_batch'))

        # In eval mode, never carry memory across calls.
        if not self.training:
            self.memory_cell.zero_mem()
            self.last_state = None
        return out


    def segment(self, **kwargs):
        """Split each non-None kwarg tensor into aligned per-segment dicts."""
        segments = []
        for k, tensor in kwargs.items():
            if tensor is not None:
                k_segments = self.split_tensor(tensor)
                for s, k_seg in enumerate(k_segments):
                    if s < len(segments):
                        segments[s][k] = k_seg
                    else:
                        segments.append({k: k_seg})

        return segments

    def split_tensor(self, tensor):
        """Split `tensor` along dim 1 into chunks of `segment_size`, aligned
        left (default), right, or center.

        NOTE(review): `None` in the `'right'` membership test is unreachable —
        the `'left'` branch already matches align is None.
        """
        align = self.rmt_config.get('segment_alignment')
        segment_size = self.rmt_config.get('segment_size')
        if align in {'left', None}:
            split_inds = list(range(0, tensor.shape[1], segment_size)) + [tensor.shape[1]]
            segments = [tensor[:, start:end] for (start, end) in zip(split_inds, split_inds[1:])]
        elif align in {'right', None}:
            split_inds = (list(range(tensor.shape[1], 0, -segment_size)) + [0])[::-1]
            segments = [tensor[:, start:end] for (start, end) in zip(split_inds, split_inds[1:])]
        elif align == 'center':
            n_seg = math.ceil(tensor.shape[1] / segment_size)
            segments = torch.chunk(tensor, n_seg, dim=1)
        else:
            raise NotImplementedError
        return segments


    def process_outputs(self, cell_outputs, **kwargs):
        """Concatenate per-segment outputs, compute the (mask-normalized)
        cross-entropy loss, and attach ACT statistics / per-segment extras."""
        out = ARMTOutput()
        full_logits = torch.cat([o.logits for o in cell_outputs], dim=1)

        labels = kwargs.get('labels')
        if labels is not None:
            # With output_only_last_segment, logits cover only a suffix of the
            # labels; align labels/mask to the logits' length from the right.
            labels = labels[:, -full_logits.size(1):]
            shift_labels = labels[..., 1:].contiguous()
            flat_labels = shift_labels.view(-1)

            labels_mask = kwargs.get('labels_mask')
            if labels_mask is not None:
                labels_mask = labels_mask[:, -full_logits.size(1):]
                shift_mask = labels_mask[..., :-1].contiguous()
                flat_labels = flat_labels[shift_mask.view(-1)]

            # Fast path: fused linear+CE from last hidden states + LM head.
            if CUT_CROSS_ENTROPY_AVAILABLE and hasattr(self.memory_cell.model, 'embed_out'):
                if cell_outputs and 'hidden_states' in cell_outputs[-1] and cell_outputs[-1].hidden_states is not None:
                    full_hidden_states = torch.cat([o.hidden_states[-1] for o in cell_outputs], dim=1)
                    shift_hidden_states = full_hidden_states[..., :-1, :].contiguous()
                    flat_hidden_states = shift_hidden_states.view(-1, shift_hidden_states.size(-1))
                    if labels_mask is not None:
                        flat_hidden_states = flat_hidden_states[shift_mask.view(-1)]
                    lm_head_weights = self.memory_cell.model.embed_out.weight
                    loss = linear_cross_entropy(
                        flat_hidden_states,
                        lm_head_weights,
                        flat_labels,
                        reduction='sum'
                    )
                else:
                    # No hidden states available: logits-based CE fallback.
                    shift_logits = full_logits[..., :-1, :].contiguous()
                    flat_logits = shift_logits.view(-1, shift_logits.size(-1))
                    if labels_mask is not None:
                        flat_logits = flat_logits[shift_mask.view(-1)]
                    loss_fct = CrossEntropyLoss(reduction='sum')
                    loss = loss_fct(flat_logits, flat_labels)
            else:
                # Plain logits-based cross entropy.
                shift_logits = full_logits[..., :-1, :].contiguous()
                flat_logits = shift_logits.view(-1, shift_logits.size(-1))
                if labels_mask is not None:
                    flat_logits = flat_logits[shift_mask.view(-1)]
                loss_fct = CrossEntropyLoss(reduction='sum')
                loss = loss_fct(flat_logits, flat_labels)

            # Normalize by number of valid target tokens, clamped to >= 1.
            if labels_mask is not None:
                denom = labels_mask[..., :-1].contiguous().view(-1).sum()
            else:
                denom = (flat_labels != -100).sum()
            denom = torch.clamp(denom, min=1)
            out['loss'] = loss / denom
        else:
            out['loss'] = 0
            # NOTE(review): env values are strings; "HF_Trainer=False" is truthy.
            if ('HF_Trainer' not in os.environ) or not os.environ['HF_Trainer']:
                out['ce_loss'] = out['loss']

        out['logits'] = full_logits
        segment_keys = ['loss', 'logits']
        if kwargs.get('output_attentions'):
            segment_keys.append('attentions')
        if kwargs.get('output_hidden_states'):
            # Concatenate per-layer hidden states across segments.
            if all(hasattr(o, 'hidden_states') and o.hidden_states is not None for o in cell_outputs):
                full_hidden_states = tuple([torch.cat(layer_hs, dim=1) for layer_hs in zip(*[o.hidden_states for o in cell_outputs])])
                segment_keys.append('hidden_states')
                out['hidden_states'] = full_hidden_states
        # Expose per-segment values (loss_0, logits_1, ...) outside HF Trainer.
        if ('HF_Trainer' not in os.environ) or not os.environ['HF_Trainer']:
            for seg_num, o in enumerate(cell_outputs):
                for key, value in o.items():
                    if any([sk in key for sk in segment_keys]):
                        out[f'{key}_{seg_num}'] = value

        # ACT statistics and ponder-time penalty.
        remainders = []
        n_updates = []
        act_on = self.rmt_config['act_on'] if 'act_on' in self.rmt_config else False
        if act_on:
            if self.memory_cell.act_type != 'model':
                for layer in self.memory_cell.get_layers():
                    remainders.append(layer.remainders / layer.segments_passed)
                    n_updates.append(layer.n_updates / layer.segments_passed)
                remainders = torch.mean(torch.stack(remainders, dim=0))
                n_updates = torch.mean(torch.stack(n_updates, dim=0))
            else:
                remainders = torch.mean(torch.stack([o['remainders'] for o in cell_outputs], dim=0))
                n_updates = torch.mean(torch.stack([o['n_updates'] for o in cell_outputs], dim=0))
            out['n_updates'] = n_updates.detach().cpu()
            out['remainders'] = remainders.detach().cpu()
            time_penalty = self.rmt_config['time_penalty']
            out['loss'] = out['loss'] + time_penalty * remainders

        return out

    def generate(self, input_ids, attention_mask, **generate_kwargs):
        """Process all but the last segment to build up memory (and cache),
        then generate from the final segment — via the sliding-window greedy
        decoder when a KV cache was carried over, else via plain generate."""
        self.memory_cell.zero_mem()
        segmented = self.segment(input_ids=input_ids, attention_mask=attention_mask)
        next_seg_kwargs = dict()
        for seg_num, segment in enumerate(segmented[:-1]):
            next_seg_len = segmented[seg_num + 1]['input_ids'].size(-1)
            _, next_seg_kwargs = self.process_segment(dict(**segment, **next_seg_kwargs), next_seg_len=next_seg_len)

        final_segment = segmented[-1]
        assert next_seg_kwargs.get('past_key_values') is None or isinstance(next_seg_kwargs.get('past_key_values'), Cache), "Sliding Window generation is not implemented for legacy cache"
        if next_seg_kwargs.get('past_key_values') is not None:
            # Sliding-window path: window size = previous segment's length.
            prev_attn_mask = segmented[-2]['attention_mask']
            legacy_cache = next_seg_kwargs['past_key_values'].to_legacy_cache()
            seg_len = segmented[-2]['input_ids'].size(-1)
            cache = DynamicCache().from_legacy_cache(legacy_cache)
            generate_kwargs['past_key_values'] = cache
            generate_kwargs['window_size'] = seg_len
            final_segment['prev_attn_mask'] = prev_attn_mask
            out = self.memory_cell.greedy_generate_sw(**final_segment, **generate_kwargs)
            return out
        else:
            out = self.memory_cell.generate(**final_segment, **generate_kwargs)
            return out
|
|
|
|
| |
| import math |
| import torch |
| from torch.nn import CrossEntropyLoss |
| from transformers import PreTrainedModel, PretrainedConfig |
| from transformers.modeling_outputs import CausalLMOutputWithCrossAttentions |
| from transformers.cache_utils import Cache, DynamicCache |
| from torch.nn.functional import relu as r |
| import torch.nn.functional as F |
| import os |
|
|
| |
|
|
|
|
class ARMTConfig(PretrainedConfig):
    """HuggingFace-style configuration for ARMT models.

    Exactly one of `base_model_name` (load pretrained weights) or
    `base_model_config` (build from config) should be set; the other must be
    None. The remaining fields control segmentation, memory layout, and
    adaptive computation time (ACT).
    """
    model_type = "armt"

    def __init__(self,
                 base_model_name=None,
                 base_model_config=None,
                 num_mem_tokens=16,
                 d_mem=512,
                 segment_size=512,
                 segment_alignment="left",
                 sliding_window=False,
                 attend_to_previous_input=False,
                 use_sink=False,
                 layers_attr="model.layers",
                 wrap_pos=False,
                 correction=True,
                 n_heads=1,
                 use_denom=True,
                 gating=False,
                 freeze_mem=False,
                 act_on=False,
                 max_hop=4,
                 act_type="associative",
                 act_format="linear",
                 noisy_halting=False,
                 constant_depth=False,
                 time_penalty=0.0,
                 **kwargs):
        super().__init__(**kwargs)
        # Providing BOTH sources for the base model is ambiguous — reject it.
        if (base_model_name is not None) and (base_model_config is not None):
            raise ValueError("Exactly one of `base_model_name` or `base_model_config` must be provided. Set the other to None.")
        self.base_model_name = base_model_name

        self.base_model_config = base_model_config
        self.num_mem_tokens = num_mem_tokens
        self.d_mem = d_mem

        self.segment_size = segment_size
        self.segment_alignment = segment_alignment
        self.sliding_window = sliding_window
        self.attend_to_previous_input = attend_to_previous_input
        self.use_sink = use_sink
        self.layers_attr = layers_attr
        self.wrap_pos = wrap_pos
        self.correction = correction
        self.n_heads = n_heads
        self.use_denom = use_denom
        self.gating = gating
        self.freeze_mem = freeze_mem
        self.act_on = act_on
        self.max_hop = max_hop
        self.act_type = act_type
        self.act_format = act_format
        self.noisy_halting = noisy_halting
        self.constant_depth = constant_depth
        self.time_penalty = time_penalty

    def get(self, attr: str, default=None):
        """Dict-like accessor so code can treat the config as a mapping.

        Simplified from a hand-rolled hasattr/getattr pair to the equivalent
        three-argument `getattr`.
        """
        return getattr(self, attr, default)
|
|
|
|
class ARMTForCausalLM(PreTrainedModel):
    """HF-compatible causal-LM wrapper that builds a base model (from a name
    or a config), wraps it in an `AssociativeMemoryCell`, and drives it with
    an `AssociativeRecurrentWrapper`.
    """
    config_class = ARMTConfig

    def __init__(self, config: ARMTConfig, **kwargs):
        super().__init__(config, **kwargs)
        from transformers import AutoConfig, AutoModelForCausalLM

        # Resolve the base model: exactly one of name / config must be given.
        base_model = None
        if getattr(config, 'base_model_name', None) is not None and getattr(config, 'base_model_config', None) is not None:
            raise ValueError("Exactly one of `base_model_name` or `base_model_config` must be provided in ARMTConfig.")
        bm_cfg = getattr(config, 'base_model_config', None)
        if bm_cfg is not None:
            # `base_model_config` may be a PretrainedConfig, a plain dict, or a
            # name/path string; normalize to a PretrainedConfig instance.
            if isinstance(bm_cfg, PretrainedConfig) and getattr(bm_cfg, 'model_type', None) != ARMTConfig.model_type:
                resolved_cfg = bm_cfg
            elif isinstance(bm_cfg, dict):
                if 'model_type' not in bm_cfg:
                    raise ValueError("`base_model_config` dict must include a 'model_type' key (e.g., 'gpt_neox', 'llama').")
                config_cls_or_instance = AutoConfig.for_model(bm_cfg['model_type'])
                # `AutoConfig.for_model` may return an instance or a class
                # depending on the transformers version; handle both.
                if isinstance(config_cls_or_instance, PretrainedConfig):
                    resolved_cfg = config_cls_or_instance
                    for k, v in bm_cfg.items():
                        setattr(resolved_cfg, k, v)
                else:
                    resolved_cfg = config_cls_or_instance.from_dict(bm_cfg)
            elif isinstance(bm_cfg, str):
                resolved_cfg = AutoConfig.from_pretrained(bm_cfg)
            else:
                raise TypeError("`base_model_config` must be a transformers.PretrainedConfig, dict, or str (name/path)")
            base_model = AutoModelForCausalLM.from_config(resolved_cfg)
        elif getattr(config, 'base_model_name', None):
            base_model = AutoModelForCausalLM.from_pretrained(config.base_model_name)
        else:
            raise ValueError("ARMTForCausalLM requires either `base_model_config` or `base_model_name` in ARMTConfig.")

        self.armt_config = config

        memory_cell = AssociativeMemoryCell(
            base_model=base_model,
            num_mem_tokens=config.num_mem_tokens,
            d_mem=config.d_mem,
            layers_attr=config.layers_attr,
            wrap_pos=config.wrap_pos,
            correction=config.correction,
            n_heads=config.n_heads,
            use_denom=config.use_denom,
            gating=config.gating,
            freeze_mem=config.freeze_mem,
            act_on=config.act_on,
            max_hop=config.max_hop,
            act_type=config.act_type,
            # `config.get` tolerates configs serialized before these fields existed.
            constant_depth=config.get('constant_depth', False),
            act_format=config.get('act_format', 'linear'),
            noisy_halting=config.get('noisy_halting', False),
            attend_to_previous_input=config.attend_to_previous_input,
            use_sink=config.use_sink
        )

        self.armt = AssociativeRecurrentWrapper(
            memory_cell,
            segment_size=config.segment_size,
            segment_alignment=config.segment_alignment,
            sliding_window=config.sliding_window,
            attend_to_previous_input=config.attend_to_previous_input,
            act_on=config.act_on,
            time_penalty=config.time_penalty
        )

    def forward(
        self,
        input_ids=None,
        labels=None,
        labels_mask=None,
        inputs_embeds=None,
        attention_mask=None,
        output_attentions=None,
        output_hidden_states=None,
        input_segmented=False,
        output_only_last_segment=False,
        num_items_in_batch=None,
    ):
        """Delegate to the recurrent wrapper's forward."""
        return self.armt(
            input_ids=input_ids,
            labels=labels,
            labels_mask=labels_mask,
            inputs_embeds=inputs_embeds,
            attention_mask=attention_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            input_segmented=input_segmented,
            output_only_last_segment=output_only_last_segment,
            num_items_in_batch=num_items_in_batch,
        )

    def generate(self, *args, **kwargs):
        """Delegate generation to the recurrent wrapper."""
        return self.armt.generate(*args, **kwargs)

    def load_state_dict(self, state_dict, strict=True, assign=False):
        """Load weights, falling back to the inner ARMT module's key layout
        when the top-level layout does not match.

        Bug fix: the fallback previously hardcoded `strict=True` (ignoring the
        caller's `strict`) and dropped the result; both are now propagated.
        """
        try:
            return super().load_state_dict(state_dict, strict, assign)
        except RuntimeError:
            print("Failed to load state, retrying with ARMT loader.")
            result = self.armt.load_state_dict(state_dict, strict=strict, assign=assign)
            print("Success!")
            return result

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, config=None, *args, **kwargs):
        # Pass-through kept for a stable call signature (explicit `config`).
        return super().from_pretrained(pretrained_model_name_or_path, *args, config=config, **kwargs)

    def gradient_checkpointing_enable(self, *args, **kwargs):
        """Forward gradient checkpointing to the wrapped base model."""
        self.armt.gradient_checkpointing_enable(*args, **kwargs)
|
|
| |
| import math |
| import os |
| import inspect |
| from typing import Optional, Tuple, Callable |
|
|
| import torch |
| import torch.nn.functional as F |
| from torch import nn |
| from torch.nn import CrossEntropyLoss |
| from transformers import PreTrainedModel, PretrainedConfig |
| from transformers.cache_utils import DynamicCache |
| import warnings |
| |
|
|
| try: |
| from liger_kernel.transformers import apply_liger_kernel_to_llama |
| LIGER_KERNEL_AVAILABLE = True |
| except ImportError: |
| print("*** Can't import liger_kernel ***") |
| LIGER_KERNEL_AVAILABLE = False |
| except Exception as e: |
| print("*** Can't import liger_kernel ***") |
| raise e |
|
|
| |
| |
|
|
def reverse_invert_attn_mask(mask: torch.Tensor) -> torch.Tensor:
    """Map an additive ("inverted") attention mask back to a 0/1 long mask.

    Entries greater than -1 (attendable positions, usually 0) become 1 and
    entries below -1 (large negative fill values) become 0. A no-op when the
    NOT_INVERT_ATTN_MASK env var is set, since masks were never inverted.
    """
    if os.environ.get("NOT_INVERT_ATTN_MASK"):
        return mask
    result = mask.clone().long()
    attendable = result > -1
    blocked = result < -1
    result[attendable] = 1
    result[blocked] = 0
    return result
|
|
def attn_mask_to_2d(mask: torch.Tensor) -> torch.Tensor:
    """Collapse a 4D [batch, heads, query, key] mask into a 2D [batch, key]
    0/1 mask: a key counts as attended if any head/query position attends it."""
    binary = reverse_invert_attn_mask(mask)
    per_head = torch.any(binary, dim=-2)   # collapse the query dimension
    collapsed = torch.any(per_head, dim=1)  # collapse the head dimension
    return collapsed.long()
|
|
def is_empty_past_key_values(past_key_values: Optional[DynamicCache], layer_idx: int) -> bool:
    """Return True when `past_key_values` holds no cached keys for `layer_idx`:
    the cache is None, has no layers, is too short, or that layer's keys are None."""
    if past_key_values is None:
        return True
    layers = past_key_values.layers
    if len(layers) == 0 or len(layers) <= layer_idx:
        return True
    return layers[layer_idx].keys is None
|
|
def invert_attn_mask(mask, dtype):
    """Invert a 0/1 attention mask into additive form via `_invert_attn_mask`,
    unless NOT_INVERT_ATTN_MASK is set (then the mask passes through).

    Rewritten from a lambda bound to a name (PEP 8 E731) to a proper `def`;
    behavior is unchanged.
    """
    if os.environ.get("NOT_INVERT_ATTN_MASK"):
        return mask
    return _invert_attn_mask(mask, dtype)
|
|
def segment_tensor(t: torch.Tensor, start_idx: int, end_idx: int, seq_len: int) -> torch.Tensor:
    """Slice `t` along dim 1 when that dimension equals `seq_len`; non-tensors
    and tensors whose dim-1 size differs are returned untouched."""
    is_sequence_tensor = (
        isinstance(t, torch.Tensor) and t.dim() >= 2 and t.size(1) == seq_len
    )
    return t[:, start_idx:end_idx, ...] if is_sequence_tensor else t
|
|
| class InnerLoopAssociativeLayerWrapper(nn.Module): |
| """ |
| A per-layer wrapper that performs associative read/write within the layer by |
| splitting the incoming full sequence into fixed-size segments on the fly. |
| |
| Unlike the outer-loop design (which segments inputs before the model), this |
| module receives the full, unsplit hidden sequence and internally iterates |
| over segments: |
| 1) Optional associative READ is applied to the segment's hidden states |
| based on the current associative memory (W_mem, z). |
| 2) Memory tokens are appended to the segment and the underlying transformer |
| layer is executed only on this augmented segment. |
| 3) The resulting memory token outputs are used to WRITE/update the |
| associative memory. |
| 4) The transformed real-token outputs replace the corresponding slice in |
| the layer output for the full sequence. |
| |
| This preserves identical behavior w.r.t. memory math while avoiding any |
| outer recurrent wrapper. |
| """ |
|
|
    def __init__(
        self,
        layer: nn.Module,
        d_model: int,
        num_mem_tokens: int,
        d_mem: int,
        segment_size: int,
        n_heads: int = 1,
        correction: bool = True,
        use_denom: bool = True,
        gating: bool = False,
        use_sink: bool = False,
        sliding_window: bool = False,
        get_memory_fn: Optional[Callable[[], torch.Tensor]] = None,
        get_sink_fn: Optional[Callable[[], Optional[torch.Tensor]]] = None,
        rotary_fn: Optional[Callable[[torch.Tensor, torch.Tensor], Tuple[torch.Tensor, torch.Tensor]]] = None,
        read_prev_states_fn: Optional[Callable[[int, int, torch.device, torch.dtype], Tuple[torch.Tensor, Optional[torch.Tensor]]]] = None,
        write_states_fn: Optional[Callable[[int, torch.Tensor, Optional[torch.Tensor]], None]] = None,
        info: Optional[dict] = None,
    ) -> None:
        """Wrap one transformer block with associative read/write memory.

        Args:
            layer: the transformer block being wrapped.
            d_model: hidden size of the wrapped block.
            num_mem_tokens: number of memory (write) tokens appended per segment.
            d_mem: width of the associative query/key projections.
            segment_size: number of real tokens per inner segment.
            n_heads: heads for the associative memory (must divide d_mem and d_model).
            correction: apply the "new information" coefficient on memory writes.
            use_denom: maintain the normalizer ``z`` alongside ``W_mem``.
            gating: per-dimension (True) vs per-head (False) write gate ``W_mb``.
            use_sink: a dedicated sink token is prepended to every segment.
            sliding_window: carry a KV cache across segments.
            get_memory_fn / get_sink_fn: callbacks returning shared token embeddings.
            rotary_fn: callback producing (cos, sin) position embeddings for a segment.
            read_prev_states_fn / write_states_fn: optional state hooks (stored; not
                used by the code in this module — presumably for external drivers).
            info: metadata dict; ``info['layer']`` indexes this block in the KV cache.
        """
        super().__init__()
        self.info = info
        self.layer = layer
        self.d_model = d_model
        self.num_mem_tokens = int(num_mem_tokens or 0)
        self.d_mem = d_mem
        self.segment_size = int(segment_size)
        self.n_heads = n_heads
        self.gating = gating
        self.use_denom = use_denom
        self.correction = correction
        self.use_sink = bool(use_sink)
        self.sliding_window = bool(sliding_window)

        # DPFP feature map with nu=3; its output width per projection is 2*nu*d_mem.
        nu = 3
        self.d_key = 2 * nu * d_mem

        assert self.d_mem % n_heads == 0 and self.d_model % n_heads == 0

        # Match the wrapped block's parameter dtype for the new projections.
        layer_dtype = next(self.layer.parameters()).dtype

        # Associative projections: query/key into d_mem, value back into d_model.
        self.W_mq = nn.Linear(d_model, d_mem, bias=False, dtype=layer_dtype)
        self.W_mk = nn.Linear(d_model, d_mem, bias=False, dtype=layer_dtype)
        self.W_mv = nn.Linear(d_model, d_model, bias=False, dtype=layer_dtype)
        if gating:
            self.W_mb = nn.Linear(d_model, d_model, dtype=layer_dtype)
        else:
            self.W_mb = nn.Linear(d_model, n_heads, dtype=layer_dtype)
        # Zero-init the value projection so the fresh memory contributes nothing.
        torch.nn.init.zeros_(self.W_mv.weight)

        self.phi = DPFP(nu)

        # Runtime flags (plain attributes, not parameters).
        self.generate_mode = False
        self.seg_num = 0

        # Hooks supplied by the owning model.
        self._get_memory = get_memory_fn
        self._get_sink = get_sink_fn
        self._rotary_fn = rotary_fn
        self._read_prev_states = read_prev_states_fn
        self._write_states = write_states_fn

        # Carried associative memory: tuple (W_mem, z) or None when empty.
        self.memory_state = None
|
| |
| def _to_heads(self, x: torch.Tensor) -> torch.Tensor: |
| bsz, seq_len, d_model = x.shape |
| x = x.reshape(bsz, seq_len, self.n_heads, d_model // self.n_heads) |
| x = x.permute(0, 2, 1, 3) |
| return x |
|
|
| def _from_heads(self, x: torch.Tensor) -> torch.Tensor: |
| bsz, n_heads, seq_len, d_head = x.shape |
| x = x.permute(0, 2, 1, 3).reshape(bsz, seq_len, n_heads * d_head) |
| return x |
|
|
| |
| def associate(self, hidden_states: torch.Tensor) -> torch.Tensor: |
| raise NotImplementedError("associate() is unused in inner-loop; uses local memory helpers instead") |
|
|
| |
| def update_mem(self, mem_tokens: torch.Tensor) -> None: |
| raise NotImplementedError("update_mem() is unused in inner-loop; uses local memory helpers instead") |
|
|
| |
| def zero_mem(self) -> None: |
| self.memory_state = None |
|
|
| def detach_mem(self) -> None: |
| self.memory_state = (self.memory_state[0].detach(), self.memory_state[1].detach()) if self.memory_state is not None else None |
|
|
| def freeze_mem(self) -> None: |
| self.W_mb.weight.requires_grad = False |
| self.W_mb.bias.requires_grad = False |
| self.W_mq.weight.requires_grad = False |
| self.W_mk.weight.requires_grad = False |
| self.W_mv.weight.requires_grad = False |
|
|
| |
| def _get_segment_positions( |
| self, position_ids: Optional[torch.LongTensor], start: int, end: int, device: torch.device |
| ) -> torch.LongTensor: |
| |
| if position_ids is not None: |
| return position_ids[:, start:end] |
| else: |
| position_ids = torch.arange(start, end, device=device).long().unsqueeze(0) |
| return position_ids |
|
|
|
|
| def pad_attention_mask(self, attention_mask: torch.Tensor, dtype: torch.dtype): |
| if self.num_mem_tokens in {0, None} and not self.use_sink: |
| return attention_mask |
| shape = list(attention_mask.shape) |
| if len(shape) == 4: |
| shape[-1] += self.num_mem_tokens + int(self.use_sink) |
| shape[-2] += self.num_mem_tokens + int(self.use_sink) |
| mask = torch.ones(*shape, dtype=dtype).to(attention_mask.device) |
| mask[..., int(self.use_sink):-self.num_mem_tokens, int(self.use_sink):-self.num_mem_tokens] = attention_mask |
| if self.use_sink: |
| mask[..., 0, 1:] = 0 |
| mask[..., :-self.num_mem_tokens, -self.num_mem_tokens:] = 0 |
| elif len(shape) == 2: |
| shape[-1] += self.num_mem_tokens + int(self.use_sink) |
| mask = torch.ones(*shape, dtype=dtype).to(attention_mask.device) |
| mask[..., int(self.use_sink):-self.num_mem_tokens] = attention_mask |
| else: |
| raise ValueError("Attention mask must be 2D or 4D") |
| return mask.to(dtype) |
|
|
|
|
| def _get_memory_tokens(self, batch_size: int) -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor]]: |
| if self._get_memory is None or self.num_mem_tokens == 0: |
| return None, None |
| memory = self._get_memory() |
| sink = self._get_sink() if self.use_sink and self._get_sink is not None else None |
| mem = memory.unsqueeze(0).expand(batch_size, -1, -1) |
| if sink is not None: |
| sink = sink.unsqueeze(0).expand(batch_size, -1, -1) |
| return mem, sink |
|
|
| |
| def _alloc_initial_mem(self, device: torch.device, dtype: torch.dtype): |
| W_mem = torch.zeros( |
| 1, |
| self.n_heads, |
| self.d_key // self.n_heads, |
| self.d_model // self.n_heads, |
| device=device, |
| dtype=dtype, |
| ) |
| z = torch.zeros(1, self.n_heads, self.d_key // self.n_heads, device=device, dtype=dtype) if self.use_denom else None |
| return W_mem, z |
|
|
| def _associate_with_mem(self, hidden_states: torch.Tensor, W_mem: torch.Tensor, z: Optional[torch.Tensor]) -> torch.Tensor: |
| q = self._to_heads(self.W_mq(hidden_states)) |
| mq = self.phi(q) |
| mq = F.normalize(mq, dim=-1, p=2.0) |
| num = torch.einsum("ihjk,ihkt->ihjt", mq, W_mem) |
| if self.use_denom and z is not None: |
| denom = torch.einsum("ihk,ihjk->ihj", z, mq)[..., None] + 1e-5 |
| hs = num / denom |
| else: |
| hs = num |
| return self._from_heads(hs) |
|
|
    def _update_mem_with_mem(
        self,
        mem_tokens: torch.Tensor,
        W_mem: torch.Tensor,
        z: Optional[torch.Tensor],
        first_seg: bool,
    ) -> tuple[torch.Tensor, Optional[torch.Tensor], bool]:
        """WRITE step: fold a segment's memory-token outputs into the memory.

        Delta-rule update: what the memory already recalls for these keys is
        subtracted from the new values before storing, so repeated writes do
        not double-count. Returns the updated ``(W_mem, z, first_seg=False)``.
        """
        # Keys use the same phi + L2-normalize pipeline as the READ path.
        k = self._to_heads(self.W_mk(mem_tokens))
        mk = self.phi(k)
        mk = F.normalize(mk, dim=-1, p=2.0)

        # Candidate values for this segment.
        new_mv = self._to_heads(self.W_mv(mem_tokens))
        if not first_seg:
            # Current recall of the memory for these keys.
            num = torch.einsum("ihjk,ihkt->ihjt", mk, W_mem)
            if self.use_denom and z is not None:
                denom = torch.einsum("ihj,ihkj->ihk", z, mk)[..., None] + 1e-5
                prev_mv = num / denom
                if self.correction:
                    # Fraction of genuinely new information per key; clipped to
                    # [0, 1] and detached so it acts as a gate, not a grad path.
                    new_info_coef = (
                        1 - denom / (torch.linalg.norm(mk, dim=-1) ** 2)[..., None]
                    )
                    new_info_coef = torch.clip(new_info_coef, 0, 1).detach()
                else:
                    new_info_coef = 1
            else:
                prev_mv = num
                new_info_coef = 1
        else:
            # Empty memory: nothing to subtract.
            prev_mv = torch.zeros_like(new_mv, device=new_mv.device)
            new_info_coef = 1

        # Store only the difference to what is already recalled.
        mv = new_mv - prev_mv
        mb = self._to_heads(torch.sigmoid(self.W_mb(mem_tokens)))
        # gating=True: per-value-dim gate ('t'); otherwise a singleton axis 'x'
        # (one gate per head) is broadcast and summed out of the product.
        einop = f"ihjk,ihjt,ihj{'t' if self.gating else 'x'}->ihkt"
        associations = torch.einsum(einop, mk, mv, mb)
        W_mem = W_mem + associations
        if self.use_denom and z is not None:
            # Accumulate (gated) key mass for the READ-side normalization.
            z = z + (new_info_coef * mk).sum(dim=-2)
        return W_mem, z, False
|
|
| |
    def forward(self, hidden_states: torch.Tensor, *args, **kwargs):
        """
        Convert positional args of the wrapped HF block into keyword args by
        introspecting the block's forward signature. This prevents accidental
        misplacement (e.g., a cache object being treated as attention_mask).
        """
        try:
            sig = inspect.signature(self.layer.forward)
            params = list(sig.parameters.values())
            # Skip the first declared parameter — for a bound forward that is
            # the hidden-states slot, which we pass positionally ourselves.
            param_names = [p.name for p in params[1:]]
            # If the remaining head is still the hidden-states slot (e.g. an
            # unbound function whose first param was `self`), drop it too.
            if len(param_names) > 0 and param_names[0] in {"hidden_states", "x"}:
                param_names = param_names[1:]
        except Exception:
            param_names = []

        # Map surplus positional args onto the resolved parameter names,
        # never overriding an explicitly passed keyword.
        for idx, arg in enumerate(args):
            if idx >= len(param_names):
                break
            name = param_names[idx]
            if name not in kwargs:
                kwargs[name] = arg

        # Normalize legacy `layer_past` caches to the `past_key_values` /
        # DynamicCache API expected downstream.
        if "layer_past" in kwargs and "past_key_values" not in kwargs:
            layer_past = kwargs.pop("layer_past")
            try:
                if isinstance(layer_past, DynamicCache):
                    kwargs["past_key_values"] = layer_past
                else:
                    kwargs["past_key_values"] = DynamicCache.from_legacy_cache(layer_past)
            except Exception:
                # Best effort: pass the legacy structure through untouched.
                kwargs["past_key_values"] = layer_past

        # The mask is rebuilt per segment inside forward_horizontal.
        attention_mask = kwargs.pop("attention_mask", None)

        return self.forward_horizontal(hidden_states, attention_mask, **kwargs)
| |
| |
    def forward_horizontal(self, hidden_states: torch.Tensor, attention_mask=None, *args, **kwargs):
        """Inner-loop forward: iterate over [sink?]+real+memory segments of the
        pre-augmented sequence, running the wrapped block per segment while
        threading the associative memory (and, with ``sliding_window``, a KV
        cache plus the previous segment's mask) across segments.

        Returns the concatenated per-segment outputs, mirroring the wrapped
        block's own return arity.
        """
        assert not self.generate_mode, "Generate mode is not supported for horizontal forward"
        assert attention_mask is None or attention_mask.dim() == 4, "Attention mask must be 4D"
        # A non-empty cache entry for this layer means we continue a prior call.
        using_cache = not is_empty_past_key_values(kwargs.get("past_key_values"), self.info['layer'])
        assert not using_cache or (kwargs.get('past_attn_mask') is not None and kwargs.get('past_attn_mask').shape[-1] == self.segment_size), "When using cache, past_attn_mask must be provided and have the same length as the segment size"

        if isinstance(hidden_states, (tuple, list)):
            hidden_states = hidden_states[0]
        bsz, seq_len, _ = hidden_states.shape

        # Default mask: causal over the full sequence, converted to the
        # additive (inverted) 4D form used below.
        if attention_mask is None:
            attention_mask = torch.ones(bsz, seq_len, device=hidden_states.device, dtype=hidden_states.dtype)
            attention_mask = attn_mask_to_4d(attention_mask, upper=False, query_len=seq_len)
            attention_mask = invert_attn_mask(attention_mask, hidden_states.dtype)
        out_full = []

        # Resume carried associative memory, or start from zeros.
        if self.memory_state is not None:
            W_mem, z = self.memory_state
            first_seg = False
        else:
            W_mem, z = self._alloc_initial_mem(hidden_states.device, hidden_states.dtype)
            first_seg = True

        provided_cache = kwargs.get("past_key_values")
        past_key_values = provided_cache if provided_cache is not None else DynamicCache()
        past_attn_mask = kwargs.get('past_attn_mask') if using_cache else None
        present_kv = None  # NOTE(review): never populated; returned as None below.

        seg_num = 0  # NOTE(review): counted but currently unused.
        # Stride covers sink + real tokens + memory tokens of one segment.
        for start in range(0, seq_len, self.segment_size+self.num_mem_tokens+int(self.use_sink)):
            real_start = start+int(self.use_sink)
            real_end = min(real_start + self.segment_size, seq_len-self.num_mem_tokens)
            end = real_end+self.num_mem_tokens
            seg_aug = hidden_states[:, start:end, :]  # sink? + real + mem slice
            seg_len = real_end - real_start

            # Mask restricted to this segment's real tokens.
            attn_mask = attention_mask[:, :, real_start:real_end, real_start:real_end]

            is_last_segment = (end >= seq_len)  # NOTE(review): currently unused.

            # READ: add the associative recall onto the segment's states.
            if not first_seg:
                assoc = self._associate_with_mem(seg_aug, W_mem, z)
                seg_aug = assoc + seg_aug

            seg_aug_len = seg_aug.size(1)
            if self.sliding_window:
                # Re-pad the real-token mask to cover sink/memory positions,
                # then prepend visibility onto the previous segment's keys.
                base_cur4d = reverse_invert_attn_mask(attn_mask)
                seg_mask = self.pad_attention_mask(base_cur4d, dtype=seg_aug.dtype)
                seg_mask = invert_attn_mask(seg_mask, seg_aug.dtype)

                if past_attn_mask is not None:
                    base_past4d = attn_mask_to_4d(attn_mask_to_2d(past_attn_mask), upper=True, query_len=seg_aug_len)
                    if self.use_sink:
                        # Sink queries must not look at the previous window.
                        base_past4d[:, :, 0, :] = 0
                    base_past4d = invert_attn_mask(base_past4d, seg_aug.dtype)

                    # Previous-window keys come first along the key axis.
                    seg_mask = torch.cat([base_past4d, seg_mask], dim=-1)
                if os.environ.get("ARMT_DEBUG_SW"):
                    print(f"[H-SEG] L{self.info['layer']} seg_len={seg_len} seg_aug_len={seg_aug_len} mask={tuple(seg_mask.shape)}")
            else:
                base_cur4d = reverse_invert_attn_mask(attn_mask)
                seg_mask = self.pad_attention_mask(base_cur4d, dtype=seg_aug.dtype)
                seg_mask = invert_attn_mask(seg_mask, seg_aug.dtype)

            seg_pos_ids = self._get_segment_positions(kwargs.get("position_ids", None), start, end, seg_aug.device)

            # Slice any per-token tensors in args/kwargs down to this segment.
            seg_args = tuple(segment_tensor(a, start, end, seq_len) if isinstance(a, torch.Tensor) else a for a in args)
            seg_kwargs = {k: segment_tensor(v, start, end, seq_len) for k, v in kwargs.items()}

            seg_kwargs["attention_mask"] = seg_mask.to(seg_aug.dtype)
            if seg_pos_ids is not None:
                seg_kwargs["position_ids"] = seg_pos_ids
            seg_kwargs["use_cache"] = self.sliding_window
            if self.sliding_window:
                seg_kwargs["past_key_values"] = past_key_values
            else:
                # Without a sliding window no cache state may reach the block.
                seg_kwargs.pop("layer_past", None)
                seg_kwargs.pop("cache_position", None)
                seg_kwargs.pop("past_key_values", None)
                seg_kwargs["use_cache"] = False

            # Rotary embeddings are recomputed per segment from its positions.
            if self._rotary_fn is not None and seg_pos_ids is not None:
                cos, sin = self._rotary_fn(seg_aug, seg_pos_ids)
                seg_kwargs["position_embeddings"] = (cos, sin)

            layer_out = self.layer(seg_aug, *seg_args, **seg_kwargs)
            if self.sliding_window:
                assert past_key_values is not None, "Past key values object must be provided"
                if os.environ.get("ARMT_DEBUG_SW"):
                    k = past_key_values.layers[self.info['layer']].keys
                    v = past_key_values.layers[self.info['layer']].values
                    print(f"[H-CACHE:pre] L{self.info['layer']} K={tuple(k.shape) if k is not None else None} V={tuple(v.shape) if v is not None else None}")
                # Trim the cache to the window, dropping memory-token entries.
                past_key_values = self.update_past_key_values_sw(past_key_values, self.segment_size)
                if os.environ.get("ARMT_DEBUG_SW"):
                    k = past_key_values.layers[self.info['layer']].keys
                    v = past_key_values.layers[self.info['layer']].values
                    print(f"[H-CACHE:post] L{self.info['layer']} K={tuple(k.shape) if k is not None else None} V={tuple(v.shape) if v is not None else None}")
            if isinstance(layer_out, tuple):
                seg_out = layer_out[0]
            else:
                seg_out = layer_out

            # WRITE: fold the memory-token outputs back into (W_mem, z).
            seg_mem_out = seg_out[:, -self.num_mem_tokens:, :]
            W_mem, z, first_seg = self._update_mem_with_mem(
                seg_mem_out, W_mem, z, first_seg
            )
            first_seg = False

            out_full.append(seg_out)

            past_attn_mask = attn_mask
            seg_num += 1

        merged = torch.cat(out_full, dim=1)

        # Carry the memory into the next call of this layer.
        self.memory_state = (W_mem, z)

        # Mirror the wrapped block's output arity; attention weights and
        # present KV states are not reconstructed across segments.
        if isinstance(layer_out, tuple):
            YELLOW = "\033[93m"
            if len(layer_out) == 1:
                return (merged,)
            elif len(layer_out) == 2:
                warnings.warn(f"{YELLOW}Last attention was not tested for horizontal forward{RESET}")
                return (merged, None)
            elif len(layer_out) == 3:
                warnings.warn(f"{YELLOW}Last attention and kv states were not tested for horizontal forward{RESET}")
                return (merged, None, present_kv)
            else:
                raise ValueError(f"Expected 1, 2 or 3 elements in layer output, got {len(layer_out)}")
        else:
            return merged
|
|
| def update_past_key_values_sw(self, past_key_values, window_size): |
| """ |
| Update past key values for sliding window attention. |
| This keeps only the most recent tokens within the window size. |
| """ |
| if is_empty_past_key_values(past_key_values, self.info['layer']): |
| return None |
| |
| |
| if hasattr(past_key_values, 'to_legacy_cache'): |
| legacy = past_key_values.to_legacy_cache() |
| legacy = past_key_values.to_legacy_cache() |
| |
| |
| k, v = legacy[self.info['layer']] |
| k = k[..., -window_size-self.num_mem_tokens:-self.num_mem_tokens, :] |
| v = v[..., -window_size-self.num_mem_tokens:-self.num_mem_tokens, :] |
| |
| past_key_values.layers[self.info['layer']].keys = k |
| past_key_values.layers[self.info['layer']].values = v |
| return past_key_values |
|
|
|
|
| class InnerLoopARMTForCausalLM(PreTrainedModel): |
| """ |
| Drop-in ARMT model that installs InnerLoopAssociativeLayerWrapper into a base |
| HF Causal LM. All segmentation happens inside each wrapped layer; no outer |
| recurrent driver is needed. |
| """ |
|
|
| |
| config_class = ARMTConfig |
|
|
| def __init__(self, config: PretrainedConfig, **kwargs): |
| global LIGER_KERNEL_AVAILABLE |
| super().__init__(config, **kwargs) |
| from transformers import AutoConfig, AutoModelForCausalLM |
|
|
| |
| base_model = None |
| bm_cfg = getattr(config, "base_model_config", None) |
| bm_name = getattr(config, "base_model_name", None) |
|
|
| if 'llama' not in bm_name: |
| LIGER_KERNEL_AVAILABLE = False |
| os.environ["ARMT_DISABLE_LIGER_KERNEL"] = "1" |
| if LIGER_KERNEL_AVAILABLE and not os.environ.get("ARMT_DISABLE_LIGER_KERNEL"): |
| apply_liger_kernel_to_llama() |
|
|
| if bm_cfg is not None and bm_name is not None: |
| raise ValueError("Exactly one of `base_model_name` or `base_model_config` must be provided in config.") |
| if bm_cfg is not None: |
| if isinstance(bm_cfg, PretrainedConfig) and getattr(bm_cfg, "model_type", None) != getattr(config, "model_type", None): |
| resolved_cfg = bm_cfg |
| elif isinstance(bm_cfg, dict): |
| from transformers import AutoConfig as HF_AutoConfig |
|
|
| if "model_type" not in bm_cfg: |
| raise ValueError("`base_model_config` dict must include a 'model_type' key.") |
| cfg_or_inst = HF_AutoConfig.for_model(bm_cfg["model_type"]) |
| if isinstance(cfg_or_inst, PretrainedConfig): |
| resolved_cfg = cfg_or_inst |
| for k, v in bm_cfg.items(): |
| setattr(resolved_cfg, k, v) |
| else: |
| resolved_cfg = cfg_or_inst.from_dict(bm_cfg) |
| elif isinstance(bm_cfg, str): |
| from transformers import AutoConfig as HF_AutoConfig |
|
|
| resolved_cfg = HF_AutoConfig.from_pretrained(bm_cfg) |
| else: |
| raise TypeError("`base_model_config` must be a transformers.PretrainedConfig, dict, or str.") |
| base_model = AutoModelForCausalLM.from_config(resolved_cfg) |
| elif bm_name is not None: |
| from transformers import AutoModelForCausalLM as HF_AutoModelForCausalLM |
|
|
| base_model = HF_AutoModelForCausalLM.from_pretrained(bm_name) |
| else: |
| raise ValueError("InnerLoopARMTForCausalLM requires either `base_model_config` or `base_model_name` in the config.") |
|
|
| |
| self.model = base_model |
|
|
| |
| self.num_mem_tokens = int(getattr(config, "num_mem_tokens", 0) or 0) |
| self.d_mem = int(getattr(config, "d_mem", 512)) |
| self.segment_size = int(getattr(config, "segment_size", 512)) |
| self.segment_alignment = getattr(config, "segment_alignment", "left") |
| if self.segment_alignment != 'left': |
| raise |
| self.layers_attr = getattr(config, "layers_attr", "model.layers") |
| self.correction = bool(getattr(config, "correction", True)) |
| self.n_heads = int(getattr(config, "n_heads", 1)) |
| self.use_denom = bool(getattr(config, "use_denom", True)) |
| self.gating = bool(getattr(config, "gating", False)) |
| self.freeze_mem_flag = bool(getattr(config, "freeze_mem", False)) |
| self.use_sink = bool(getattr(config, "use_sink", False)) |
| self.sliding_window = bool(getattr(config, "sliding_window", False)) |
|
|
| |
| emb = self.model.get_input_embeddings() |
| d_model = emb.embedding_dim |
| memory_dim = getattr(self.model.config, "n_embd", getattr(self.model.config, "hidden_size", d_model)) |
| |
| |
| |
| |
| |
| |
| memory_weights = torch.empty( |
| (self.num_mem_tokens, memory_dim), device=emb.weight.device, dtype=emb.weight.dtype |
| ) |
| |
| torch.nn.init.normal_(memory_weights, mean=0.0, std=0.02) |
| self.memory = nn.Parameter(memory_weights, requires_grad=True) |
| if self.use_sink: |
| self.sink = nn.Parameter( |
| torch.randn((1, memory_dim), device=emb.weight.device, dtype=emb.weight.dtype), requires_grad=True |
| ) |
| |
| def _get_layers_from_model(model_root: nn.Module): |
| obj = model_root |
| for attr in self.layers_attr.split("."): |
| obj = getattr(obj, attr) |
| return obj |
|
|
| layers = _get_layers_from_model(self.model) |
| rotary_fn = None |
| if hasattr(self.model, "model") and hasattr(self.model.model, "rotary_emb"): |
| rotary_fn = self.model.model.rotary_emb |
| elif hasattr(self.model, "gpt_neox") and hasattr(self.model.gpt_neox, "rotary_emb"): |
| rotary_fn = self.model.gpt_neox.rotary_emb |
|
|
| for i in range(len(layers)): |
| layers[i] = InnerLoopAssociativeLayerWrapper( |
| layer=layers[i], |
| d_model=d_model, |
| num_mem_tokens=self.num_mem_tokens, |
| d_mem=self.d_mem, |
| segment_size=self.segment_size, |
| n_heads=self.n_heads, |
| correction=self.correction, |
| use_denom=self.use_denom, |
| gating=self.gating, |
| use_sink=self.use_sink, |
| sliding_window=self.sliding_window, |
| get_memory_fn=lambda self_ref=self: self_ref.memory, |
| get_sink_fn=lambda self_ref=self: getattr(self_ref, "sink", None), |
| rotary_fn=rotary_fn, |
| info={"layer": i}, |
| ) |
|
|
| if self.freeze_mem_flag: |
| for layer in _get_layers_from_model(self.model): |
| layer.freeze_mem() |
|
|
|
|
| |
| self.get_layers = lambda: _get_layers_from_model(self.model) |
|
|
| self.vertical_mode = False |
|
|
| |
| def generate_mode(self, is_on: bool): |
| for layer in self.get_layers(): |
| layer.generate_mode = is_on |
|
|
| def zero_mem(self): |
| """Reset memory state for all layers.""" |
| for layer in self.get_layers(): |
| layer.zero_mem() |
|
|
| def detach_mem(self): |
| """Detach memory state for all layers.""" |
| for layer in self.get_layers(): |
| layer.detach_mem() |
|
|
| def augment_sequence(self, hidden_states: torch.Tensor, mem: torch.Tensor, sink: torch.Tensor = None): |
| segments = torch.split(hidden_states, self.segment_size, dim=1) |
| if sink is not None: |
| augmented_segments = [torch.cat([sink.to(segment.dtype).to(segment.device), segment, mem.to(segment.dtype).to(segment.device)], dim=1) for segment in segments] |
| else: |
| augmented_segments = [torch.cat([segment, mem.to(segment.dtype).to(segment.device)], dim=1) for segment in segments] |
| augmented_sequence = torch.cat(augmented_segments, dim=1) |
|
|
| return augmented_sequence |
|
|
| def clean_sequence(self, hidden_states: torch.Tensor): |
| augmented_segments = torch.split(hidden_states, self.segment_size+self.num_mem_tokens+int(self.use_sink), dim=1) |
| segments = [segment[:, int(self.use_sink):-self.num_mem_tokens] for segment in augmented_segments] |
| return torch.cat(segments, dim=1) |
|
|
| def augment_attention_mask(self, attention_mask: torch.Tensor): |
| segments = torch.split(attention_mask, self.segment_size, dim=1) |
| if self.use_sink: |
| augmented_segments = [torch.cat([ |
| torch.ones(segment.shape[0], 1, device=segment.device, dtype=segment.dtype), |
| segment, |
| torch.ones(segment.shape[0], self.num_mem_tokens, device=segment.device, dtype=segment.dtype) |
| ], dim=1) for segment in segments] |
| else: |
| augmented_segments = [torch.cat([ |
| segment, |
| torch.ones(segment.shape[0], self.num_mem_tokens, device=segment.device, dtype=segment.dtype) |
| ], dim=1) for segment in segments] |
| augmented_attention_mask = torch.cat(augmented_segments, dim=1) |
| return augmented_attention_mask |
|
|
| def augment_labels(self, labels): |
| if labels is None: |
| return None |
| first = labels[:, :1] |
| segments = torch.split(labels[:, 1:], self.segment_size, dim=1) |
| if self.use_sink: |
| augmented_segments = [torch.cat([ |
| -100 * torch.ones(segment.shape[0], 1, device=segment.device, dtype=segment.dtype), |
| segment, |
| -100 * torch.ones(segment.shape[0], self.num_mem_tokens, device=segment.device, dtype=segment.dtype) |
| ], dim=1) for segment in segments] |
| else: |
| augmented_segments = [torch.cat([ |
| segment, |
| -100 * torch.ones(segment.shape[0], self.num_mem_tokens, device=segment.device, dtype=segment.dtype) |
| ], dim=1) for segment in segments] |
| augmented_segments = torch.cat(augmented_segments, dim=1) |
| augmented_labels = torch.cat([first, augmented_segments], dim=1) |
| return augmented_labels |
|
|
| def augment(self, input_ids, inputs_embeds, attention_mask, labels): |
| if input_ids is not None: |
| assert inputs_embeds is None, "input_ids and inputs_embeds cannot be provided together" |
| hidden_states = self.model.get_input_embeddings()(input_ids) |
| elif inputs_embeds is not None: |
| hidden_states = inputs_embeds |
| else: |
| raise ValueError("Either input_ids or inputs_embeds must be provided") |
| mem = self.memory.unsqueeze(0).expand(hidden_states.size(0), -1, -1) |
| sink = self.sink.unsqueeze(0).expand(hidden_states.size(0), -1, -1) if self.use_sink else None |
|
|
| augmented_hidden_states = self.augment_sequence(hidden_states, mem, sink) |
| augmented_attention_mask = self.augment_attention_mask(attention_mask) |
| augmented_labels = self.augment_labels(labels) |
| return augmented_hidden_states, augmented_attention_mask, augmented_labels |
|
|
    def forward(
        self,
        input_ids=None,
        labels=None,
        labels_mask=None,
        inputs_embeds=None,
        attention_mask=None,
        output_attentions=None,
        output_hidden_states=None,
        output_only_last_segment=False,
        num_items_in_batch=None,
        use_cache=None,
        past_key_values=None,
    ):
        """Dispatch to vertical (outer segmented) or horizontal (inner
        segmented) forward, folding ``labels_mask`` into the labels first.

        Positions where ``labels_mask`` is 0 are rewritten to -100 so they are
        ignored by the cross-entropy loss downstream.
        """
        if labels_mask is not None:
            assert labels_mask.any(), "labels_mask must not be all zeros"

        # Fold labels_mask into the labels themselves (-100 == ignore index).
        effective_labels = labels
        if labels is not None and labels_mask is not None:
            if isinstance(labels_mask, torch.Tensor):
                mask_bool = labels_mask.bool() if labels_mask.dtype != torch.bool else labels_mask
                effective_labels = labels.masked_fill(~mask_bool, -100)
            else:
                raise ValueError("labels_mask must be a torch.Tensor")

        # Default mask: everything visible, matching the input's device/dtype.
        if attention_mask is None:
            if input_ids is not None:
                attention_mask = torch.ones(input_ids.shape[0], input_ids.shape[1], device=input_ids.device, dtype=input_ids.dtype)
            else:
                attention_mask = torch.ones(inputs_embeds.shape[0], inputs_embeds.shape[1], device=inputs_embeds.device, dtype=inputs_embeds.dtype)

        if self.vertical_mode:
            return self.forward_vertical(
                input_ids=input_ids,
                labels=effective_labels,
                inputs_embeds=inputs_embeds,
                attention_mask=attention_mask,
                output_attentions=output_attentions,
                output_hidden_states=output_hidden_states,
                output_only_last_segment=output_only_last_segment,
                num_items_in_batch=num_items_in_batch,
                use_cache=use_cache,
                past_key_values=past_key_values,
                past_attn_mask=None
            )
        else:
            return self.forward_horizontal(
                input_ids=input_ids,
                labels=effective_labels,
                inputs_embeds=inputs_embeds,
                attention_mask=attention_mask,
                output_attentions=output_attentions,
                output_hidden_states=output_hidden_states,
                output_only_last_segment=output_only_last_segment,
                num_items_in_batch=num_items_in_batch,
                use_cache=use_cache,
                past_key_values=past_key_values
            )
| def forward_vertical( |
| self, |
| input_ids=None, |
| labels=None, |
| inputs_embeds=None, |
| attention_mask=None, |
| output_attentions=None, |
| output_hidden_states=None, |
| output_only_last_segment=False, |
| num_items_in_batch=None, |
| use_cache=None, |
| past_key_values=None, |
| past_attn_mask=None, |
| ): |
| assert not self.training or os.environ.get("ARMT_DISABLE_LIGER_KERNEL"), "Liger kernel is not supported for training in vertical mode, to disable liger kernel, set ARMT_DISABLE_LIGER_KERNEL=1" |
| |
| if input_ids is not None: |
| assert inputs_embeds is None |
| B, L = input_ids.shape |
| device = input_ids.device |
| elif inputs_embeds is not None: |
| B, L, _ = inputs_embeds.shape |
| device = inputs_embeds.device |
| else: |
| raise ValueError("Either input_ids or inputs_embeds must be provided") |
| dtype = next(self.model.parameters()).dtype |
|
|
| augmented_hidden_states, augmented_attention_mask, augmented_labels = self.augment(input_ids, inputs_embeds, attention_mask, labels) |
|
|
| |
| def split_tensor(tensor: torch.Tensor, segment_size: int): |
| return torch.split(tensor, segment_size+self.num_mem_tokens+int(self.use_sink), dim=1) |
|
|
| |
| |
| seg_inputs_embeds = split_tensor(augmented_hidden_states, self.segment_size) |
| seg_attention_mask = split_tensor(augmented_attention_mask, self.segment_size) if attention_mask is not None else None |
| seg_labels = split_tensor(augmented_labels, self.segment_size) if labels is not None else None |
| |
| num_segments = len(seg_inputs_embeds) |
| segments = [] |
| for i in range(num_segments): |
| segments.append({ |
| "inputs_embeds": seg_inputs_embeds[i], |
| "attention_mask": None if seg_attention_mask is None else seg_attention_mask[i], |
| "labels": None if seg_labels is None else seg_labels[i], |
| }) |
|
|
| |
| use_sliding = bool(self.sliding_window) |
| shared_cache = past_key_values if (use_sliding and past_key_values is not None) else (DynamicCache() if use_sliding else None) |
| past_attn_mask = past_attn_mask if use_sliding else None |
| |
| pos_offset = 0 |
|
|
| |
| seg_outputs = [] |
| layers = self.get_layers() |
| for seg in segments: |
| seg_len = seg["inputs_embeds"].size(1) |
| if seg.get("attention_mask") is None: |
| base_2d = torch.ones(B, seg_len, device=device, dtype=dtype) |
| else: |
| base_2d = seg["attention_mask"] |
| cur4d = attn_mask_to_4d(base_2d, upper=False, query_len=seg_len) |
| cur4d = invert_attn_mask(cur4d, dtype=dtype) |
|
|
| |
| position_ids = torch.arange(pos_offset, pos_offset + seg_len, device=device).long().unsqueeze(0) |
|
|
| |
| orig_forwards = [ly.forward for ly in layers] |
| seg_past_attn_mask = past_attn_mask |
| def _inject_mask(orig_fn, mask): |
| def _wrapped(hs, *a, **k): |
| |
| if mask is not None: |
| if 'past_attn_mask' not in k: |
| k['past_attn_mask'] = mask |
| |
| if 'past_key_values' not in k or k['past_key_values'] is None: |
| k['past_key_values'] = shared_cache |
| |
| if hasattr(k['past_key_values'], 'layers') and len(k['past_key_values'].layers) < len(layers): |
| |
| needed = len(layers) - len(k['past_key_values'].layers) |
| k['past_key_values'].layers.extend([type(k['past_key_values'].layers[0])() for _ in range(needed)]) |
| k['use_cache'] = True |
| return orig_fn(hs, *a, **k) |
| return _wrapped |
| for i, ly in enumerate(layers): |
| ly.forward = _inject_mask(orig_forwards[i], seg_past_attn_mask) |
|
|
| out = self.model( |
| input_ids=seg.get("input_ids"), |
| inputs_embeds=seg.get("inputs_embeds"), |
| attention_mask=cur4d, |
| position_ids=position_ids, |
| output_attentions=output_attentions, |
| output_hidden_states=output_hidden_states, |
| use_cache=use_sliding, |
| past_key_values=shared_cache if use_sliding else None, |
| ) |
| if os.environ.get("ARMT_DEBUG_SW"): |
| print(f"[V-SEG] seg_len={seg_len} cur4d={tuple(cur4d.shape)} pos=({int(position_ids[0,0])},{int(position_ids[0,-1])})") |
| if hasattr(out, 'past_key_values') and out.past_key_values is not None: |
| try: |
| k = out.past_key_values.layers[0].keys |
| v = out.past_key_values.layers[0].values |
| print(f"[V-CACHE:out] L0 K={tuple(k.shape) if k is not None else None} V={tuple(v.shape) if v is not None else None}") |
| except Exception: |
| pass |
| |
| for i, ly in enumerate(layers): |
| ly.forward = orig_forwards[i] |
| seg_outputs.append(out) |
|
|
| if use_sliding: |
| |
| shared_cache = out.past_key_values if hasattr(out, 'past_key_values') else shared_cache |
| if os.environ.get("ARMT_DEBUG_SW") and shared_cache is not None: |
| try: |
| k = shared_cache.layers[0].keys |
| v = shared_cache.layers[0].values |
| print(f"[V-CACHE:posttrim] L0 K={tuple(k.shape) if k is not None else None} V={tuple(v.shape) if v is not None else None}") |
| except Exception: |
| pass |
| past_attn_mask = cur4d[:, :, int(self.use_sink):-self.num_mem_tokens, int(self.use_sink):-self.num_mem_tokens] |
| pos_offset += seg_len |
|
|
| |
| |
| full_logits = torch.cat([o.logits for o in seg_outputs], dim=1) if len(seg_outputs) > 1 else seg_outputs[0].logits |
|
|
| result = {} |
| result["logits"] = self.clean_sequence(full_logits) |
|
|
| |
| if labels is not None: |
| labels = labels[:, -full_logits.size(1):] |
| shift_labels = labels[..., 1:].contiguous() |
| flat_labels = shift_labels.view(-1) |
|
|
| if labels_mask is not None: |
| labels_mask = labels_mask[:, -full_logits.size(1):] |
| shift_mask = labels_mask[..., :-1].contiguous() |
| else: |
| shift_mask = None |
|
|
| shift_logits = full_logits[..., :-1, :].contiguous() |
| flat_logits = shift_logits.view(-1, shift_logits.size(-1)) |
| if shift_mask is not None: |
| flat_logits = flat_logits[shift_mask.view(-1)] |
| flat_labels = flat_labels[shift_mask.view(-1)] |
| loss_fct = CrossEntropyLoss(reduction='sum') |
| loss = loss_fct(flat_logits, flat_labels) |
|
|
| if labels_mask is not None: |
| denom = labels_mask[..., :-1].contiguous().view(-1).sum() |
| else: |
| denom = (flat_labels != -100).sum() |
| denom = torch.clamp(denom, min=1) |
| result["loss"] = loss / denom |
| |
| if output_hidden_states: |
| if all(getattr(o, 'hidden_states', None) is not None for o in seg_outputs): |
| |
| full_hidden_states = tuple([ |
| torch.cat(layer_hs, dim=1) |
| for layer_hs in zip(*[o.hidden_states for o in seg_outputs]) |
| ]) |
| result["hidden_states"] = full_hidden_states |
|
|
| return result |
| |
| |
| def forward_horizontal( |
| self, |
| input_ids=None, |
| labels=None, |
| inputs_embeds=None, |
| attention_mask=None, |
| output_attentions=None, |
| output_hidden_states=None, |
| output_only_last_segment=False, |
| num_items_in_batch=None, |
| use_cache=None, |
| past_key_values=None, |
| ): |
| augmented_hidden_states, augmented_attention_mask, augmented_labels = self.augment(input_ids, inputs_embeds, attention_mask, labels) |
| out = self.model( |
| labels=augmented_labels, |
| inputs_embeds=augmented_hidden_states, |
| attention_mask=augmented_attention_mask, |
| output_attentions=output_attentions, |
| output_hidden_states=output_hidden_states, |
| use_cache=use_cache, |
| past_key_values=past_key_values, |
| ) |
| if not LIGER_KERNEL_AVAILABLE: |
| out.logits = self.clean_sequence(out.logits) |
| self.zero_mem() |
| return out |
|
|
| def generate(self, input_ids, attention_mask=None, **generate_kwargs): |
| """ |
| Generate tokens using the inner-loop model with proper sliding window attention. |
| This method should produce the same logits as the forward method for alignment. |
| """ |
|
|
| warnings.warn("Efficient generation is not implemented") |
| if self.sliding_window: |
| return self._generate_inefficient(input_ids, attention_mask, **generate_kwargs) |
| else: |
| |
| return self._generate_inefficient(input_ids, attention_mask, **generate_kwargs) |
| |
| |
| def _generate_standard(self, input_ids, attention_mask=None, **generate_kwargs): |
| """Standard generation without sliding window.""" |
| generate_kwargs['output_scores'] = generate_kwargs.get('return_logits', False) |
| generate_kwargs['return_dict_in_generate'] = generate_kwargs.get('return_logits', False) |
| generate_kwargs.pop('return_logits') |
| out = self.model.generate(input_ids=input_ids, attention_mask=attention_mask, **generate_kwargs) |
| if generate_kwargs.get('output_scores', False): |
| print(out.scores) |
| return out.sequences, out.scores |
| else: |
| return out.sequences |
| |
| def _generate_inefficient(self, input_ids, attention_mask=None, **generate_kwargs): |
| """ |
| Generate tokens using sliding window attention that matches the forward method. |
| This ensures alignment between generate and forward methods. |
| INEFFICIENT: recomputes the entire sequence on every token generation. |
| Kept for reference and testing purposes. |
| """ |
| max_new_tokens = generate_kwargs.get('max_new_tokens', 1) |
| eos_token_id = generate_kwargs.get('eos_token_id', None) |
| return_logits = generate_kwargs.get('return_logits', False) |
| |
| generated_ids = None |
| all_logits = [] |
|
|
| |
| for i in range(max_new_tokens): |
| |
| if generated_ids is not None: |
| current_input_ids = torch.cat([input_ids, generated_ids], dim=-1) |
| current_attention_mask = torch.cat([attention_mask, torch.ones_like(generated_ids)], dim=-1) |
| else: |
| current_input_ids = input_ids |
| current_attention_mask = attention_mask |
| |
| |
| |
| self.zero_mem() |
| |
| with torch.no_grad(): |
| outputs = self.forward( |
| input_ids=current_input_ids, |
| attention_mask=current_attention_mask |
| ) |
| next_token_logits = outputs.logits[:, -1, :] |
| |
| |
| next_token_id = torch.argmax(next_token_logits, dim=-1).unsqueeze(-1) |
| |
| if generated_ids is not None: |
| generated_ids = torch.cat([generated_ids, next_token_id], dim=-1) |
| else: |
| generated_ids = next_token_id |
| |
| |
| if return_logits: |
| all_logits.append(next_token_logits) |
| |
| |
| if eos_token_id is not None and (next_token_id == eos_token_id).all(): |
| break |
| |
| if return_logits: |
| |
| return generated_ids, torch.stack(all_logits, dim=1) |
| else: |
| return generated_ids |
|
|
    def _generate_sliding_window(self, input_ids, attention_mask=None, **generate_kwargs):
        """
        Generate tokens using sliding window attention with efficient caching.
        Uses the base model directly with past_key_values to avoid recomputing the entire sequence.
        This method should produce the same logits as the forward method for alignment.

        Recognized generate_kwargs: ``max_new_tokens`` (default 1),
        ``eos_token_id``, ``return_logits``. Returns generated ids, or
        ``(generated_ids, stacked_logits)`` when ``return_logits`` is truthy.
        """
        # Flip the wrapper into generation mode for the whole call; the
        # finally-block below guarantees it is restored even on failure.
        self.generate_mode(True)
        try:
            max_new_tokens = generate_kwargs.get('max_new_tokens', 1)
            eos_token_id = generate_kwargs.get('eos_token_id', None)
            return_logits = generate_kwargs.get('return_logits', False)

            # Start from a clean memory state so the prompt pass matches forward().
            self.zero_mem()

            # A concrete mask is required for the per-step cache extension.
            if attention_mask is None:
                attention_mask = torch.ones_like(input_ids)

            # Full wrapper pass over the prompt: produces the (aligned) logits
            # for the first generated token.
            initial_outputs = self.forward(
                input_ids=input_ids,
                attention_mask=attention_mask
            )

            next_token_logits = initial_outputs.logits[:, -1, :]

            generated_ids = None
            all_logits = []

            # NOTE(review): the prompt is processed twice — once via
            # self.forward above (first-token logits aligned with forward())
            # and once via the base model below (to prime the KV cache).
            # Presumably intentional; confirm the duplicated compute is needed.
            base_model = self.model
            # Cache window: one segment plus memory tokens plus the optional sink token.
            window_size = self.segment_size + self.num_mem_tokens + int(self.use_sink)

            try:
                # Prime the base model's KV cache on the raw prompt.
                base_outputs = base_model(
                    input_ids=input_ids,
                    attention_mask=attention_mask,
                    use_cache=True
                )
                past_key_values = base_outputs.past_key_values

                for i in range(max_new_tokens):
                    # Greedy decoding: argmax over the most recent logits.
                    next_token_id = torch.argmax(next_token_logits, dim=-1).unsqueeze(-1)

                    if generated_ids is not None:
                        generated_ids = torch.cat([generated_ids, next_token_id], dim=-1)
                    else:
                        generated_ids = next_token_id

                    if return_logits:
                        all_logits.append(next_token_logits)

                    # Stop once every sequence in the batch emitted EOS.
                    if eos_token_id is not None and (next_token_id == eos_token_id).all():
                        break

                    # Single-token incremental step against the cached context.
                    with torch.no_grad():
                        next_outputs = base_model(
                            input_ids=next_token_id,
                            attention_mask=torch.ones_like(next_token_id),
                            past_key_values=past_key_values,
                            use_cache=True
                        )
                    next_token_logits = next_outputs.logits[:, -1, :]
                    past_key_values = next_outputs.past_key_values

                    # Trim the cache back down to the sliding window.
                    if past_key_values is not None:
                        past_key_values = self.update_past_key_values_sw(past_key_values, window_size)

            except Exception as e:
                # Surface diagnostics before failing hard; caching support
                # varies across base models.
                print(f"Error implementing efficient generation: {e}")
                print("This suggests the base model doesn't support the expected interface")
                print("Why could this happen?")
                print("1. The base model might not support past_key_values")
                print("2. The attention mask handling might be incompatible")
                print("3. The memory tokens might interfere with caching")
                print("4. The inner loop wrapper might not be compatible with base model caching")
                raise RuntimeError(f"Efficient generation failed: {e}")

            if return_logits:
                return generated_ids, torch.stack(all_logits, dim=1)
            else:
                return generated_ids
        finally:
            self.generate_mode(False)
|
|
| def load_state_dict(self, state_dict, strict: bool = True, assign: bool = False): |
| try: |
| return super().load_state_dict(state_dict, strict, assign) |
| except RuntimeError: |
| |
| self.model.load_state_dict(state_dict, strict=True) |
| return |
|
|
| def zero_mem(self): |
| for layer in self.get_layers(): |
| layer.zero_mem() |
|
|
| def detach_mem(self): |
| for layer in self.get_layers(): |
| layer.detach_mem() |
|
|
| def freeze_mem(self): |
| for layer in self.get_layers(): |
| layer.freeze_mem() |
|
|
|
|