| """ARB — Any Relational Bit. Core model assembly.""" |
| import warnings |
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
| from math import ceil as _ceil |
|
|
| _ceil_div = lambda a, b: _ceil(a / b) if b > 0 else 0 |
|
|
| from .config import VOCAB, HIDDEN_DIM, SPECIAL_VOCAB, CTX, THRESHOLD, CODEBOOK_DIM, CODEBOOK_SIZE, KV_LEDGER_SIZE, KQ_CACHE_SIZE, MEMGRAM_STRUCT_PRIMES, MEMGRAM_CONV_PRIMES, MEMGRAM_EMBED_DIM, MEMGRAM_KEY_DIM, KGVQ_CODEBOOK_SIZE, KGVQ_CODEBOOK_DIM, K_MAX_COMPOSITES, MG_TOP_K |
| from .kernel.ternary_scale import TScaleType, TernaryScaleTensor, TernaryRMSNorm, _HAS_TRITON |
| try: |
| from .kernel.ternary_scale import _triton_apply_accumulated_flips |
| except ImportError: |
| _triton_apply_accumulated_flips = None |
| from .converters.convert_to_ternary8 import pack_ternary |
| try: |
| from .kernel.ternary_scale import _TritonTernaryEmbedFn |
| except ImportError: |
| _TritonTernaryEmbedFn = None |
| from .sequencers import ByteEmbedding, MultimodalSequencer |
| from .vq import SharedVQ |
| from .components import ( |
| ByteHead, OutputRouter, |
| MemGram, LossComponents, LossWeights, |
| CompositeProposalHead, MoEGraph, |
| ) |
| from .decoders import VideoHead, TalkerHead |
| from .components import _BOUNDARY_TOKEN_MAP as _BOUNDARY_MAP |
| from .attention import KVLedger, KQCache, ContextAttentionScheduler |
| from .kernel.flash_vq import FlashVQCodebook |
| def _extract_boundary_from_input(x): |
| if x.dim() != 2: |
| return None |
| first_token = x[0, 0].item() |
| if first_token in _BOUNDARY_MAP: |
| return first_token |
| for tok in x[0].tolist(): |
| if tok in _BOUNDARY_MAP: |
| return tok |
| return None |
|
|
|
|
class ARBModel(nn.Module):
    """ARB top-level model.

    Forward pipeline: byte embedding -> multimodal sequencer -> optional shared
    VQ bridge (projected back to ``HIDDEN_DIM`` when it returns codebook-space
    vectors) -> optional MemGram mixing -> optional attention + MoEGraph
    reasoning -> optional composite proposal head -> text / video / talker
    output head.

    The class also hosts the integer ("ternary") update machinery
    (``_ternary_update_memory`` and helpers). That machinery folds per-module
    ``_hook_*`` tensors (presumably left by backward hooks in the kernel
    modules) into ``T_accum`` / ``E_accum`` / ``corr_accum`` accumulators, and
    from there into packed trits (``T_packed``) and scale exponents (``E``).
    It keeps no persistent float optimizer state.
    """

    def __init__(self, tscale_type=TScaleType.T32, threshold=THRESHOLD,
                 max_graph_hops=4, max_moe_iters=4, halt_threshold=0.99,
                 enable_image=False, enable_audio=False, enable_vq=True, enable_graph=True,
                 enable_memory_modules=False, enable_moe=True,
                 shared_vq_size=None, kgvq_codebook_size=None,
                 enable_attention=True, enable_output_router=True,
                 enable_video_output=True, enable_talker_output=True):
        """Build all sub-modules; disabled features are stored as ``None``.

        Args:
            tscale_type: ternary scale type forwarded to every ternary layer.
            threshold: stored on ``self.threshold`` and passed to MoEGraph.
            max_graph_hops, enable_moe: accepted but not read in this body.
            max_moe_iters, halt_threshold: MoEGraph iteration limits.
            enable_image, enable_audio: enable the extra modality sequencers.
            enable_vq: create the shared VQ bridge and its projection back to
                ``HIDDEN_DIM``.
            enable_graph: create MoEGraph and the composite head. Effective
                only together with ``enable_vq``.
            enable_memory_modules: create the MemGram module.
            shared_vq_size: codebook size for the shared VQ bridge.
            kgvq_codebook_size: composite-head codebook size; defaults to
                ``KGVQ_CODEBOOK_SIZE``.
            enable_attention: create the KV ledger, KQ cache and attention
                scheduler.
            enable_output_router, enable_video_output, enable_talker_output:
                create the respective output modules.
        """
        super().__init__()
        self.image_enabled = enable_image
        self.audio_enabled = enable_audio
        self.embedding = ByteEmbedding(tscale_type=tscale_type)
        self.multimodal_sequencer = MultimodalSequencer(
            tscale_type=tscale_type,
            enable_text=True, enable_image=enable_image, enable_audio=enable_audio,
        )
        self.text_sequencer = self.multimodal_sequencer.text
        self.image_sequencer = self.multimodal_sequencer.image
        self.audio_sequencer = self.multimodal_sequencer.audio
        self.vq_enabled = enable_vq
        self.bridge = SharedVQ(
            codebook_size=shared_vq_size,
            tscale_type=tscale_type, enable_image=enable_image, enable_audio=enable_audio,
        ) if enable_vq else None
        self.vq_to_trigram = TernaryScaleTensor(CODEBOOK_DIM, HIDDEN_DIM, tscale_type=tscale_type) if enable_vq else None
        self.vq_to_trigram_norm = TernaryRMSNorm(HIDDEN_DIM, tscale_type=tscale_type) if enable_vq else None
        # The graph indexes VQ codes, so it cannot run without the bridge.
        self.graph_enabled = enable_graph and enable_vq
        graph_vocab_size = self.bridge.total_codebook_size if self.graph_enabled else None
        self.threshold = threshold
        self.moegraph = MoEGraph(
            trigram_dim=HIDDEN_DIM, codebook_size=graph_vocab_size or CODEBOOK_SIZE,
            max_iters=max_moe_iters, halt_threshold=halt_threshold,
            top_k=MG_TOP_K,
        ) if self.graph_enabled else None
        self.byte_head = ByteHead(tscale_type=tscale_type)

        self.composite_head = CompositeProposalHead(
            dim=HIDDEN_DIM, codebook_dim=KGVQ_CODEBOOK_DIM,
            k_max=K_MAX_COMPOSITES, codebook_size=kgvq_codebook_size or KGVQ_CODEBOOK_SIZE,
            tscale_type=tscale_type,
        ) if self.graph_enabled else None
        self.output_router = OutputRouter(tscale_type=tscale_type, depth=3) if enable_output_router else None
        self.video_head = VideoHead(tscale_type=tscale_type) if enable_video_output else None
        self.talker_head = TalkerHead(tscale_type=tscale_type) if enable_talker_output else None
        self.memgram = MemGram(
            struct_primes=MEMGRAM_STRUCT_PRIMES,
            conv_primes=MEMGRAM_CONV_PRIMES,
            embed_dim=MEMGRAM_EMBED_DIM, key_dim=MEMGRAM_KEY_DIM, hidden_dim=HIDDEN_DIM,
        ) if enable_memory_modules else None
        self.memgram_enabled = self.memgram is not None

        self.kv_ledger = KVLedger(max_size=KV_LEDGER_SIZE) if enable_attention else None
        self.kq_cache = KQCache(max_size=KQ_CACHE_SIZE) if enable_attention else None
        self.attention = ContextAttentionScheduler(dim=HIDDEN_DIM) if enable_attention else None
        self.attention_enabled = bool(enable_attention)

    def forward(self, x, targets=None, commitment_warmup_weight=1.0,
                act_warmup_mode=False, ponder_lambda=0.01, images=None,
                audio=None, timestep=0, loss_weights=None, output_mode=None):
        """Run one forward pass and, when *targets* is given, assemble the losses.

        Args:
            x: byte-token ids for ``self.embedding``. Presumably a ``(B, T)``
                long tensor — TODO confirm.
            targets: next-byte targets. When given, the byte head is always
                used. Targets are flattened against ``logits[:, :-1]``, so
                they must hold one fewer position per row than the logits.
            commitment_warmup_weight: multiplier applied to the summed VQ loss.
            act_warmup_mode, ponder_lambda: accepted but not read in this body.
            images, audio: optional extra modality inputs.
            timestep: forwarded to the VQ bridge.
            loss_weights: optional ``LossWeights``; defaults to ``LossWeights()``.
            output_mode: ``"text"``, ``"video"``, ``"audio"`` / ``"talker"``,
                or ``None`` to let the output router (if any) decide.

        Returns:
            ``(logits, losses, all_indices, None)``. ``losses`` is a
            ``LossComponents`` or ``None``. ``all_indices`` is the
            concatenated VQ indices fed to MoEGraph, or ``None`` when the
            graph path did not run.

        Raises:
            ValueError: a modality input or output mode was requested for a
                disabled component.
        """
        has_image = images is not None
        has_audio = audio is not None
        if has_image and (not self.image_enabled or self.image_sequencer is None):
            raise ValueError("images provided but model has enable_image=False")
        if has_audio and (not self.audio_enabled or self.audio_sequencer is None):
            raise ValueError("audio provided but model has enable_audio=False")

        # --- Per-modality sequencing ---------------------------------------
        embedded = self.embedding(x)
        seq_inputs = {'text': embedded}
        if has_image:
            seq_inputs['image'] = images
        if has_audio:
            seq_inputs['audio'] = audio
        seq_outputs = self.multimodal_sequencer(seq_inputs)
        relational = seq_outputs['text']

        # --- Shared VQ bridge ----------------------------------------------
        indices_dict = {}
        if self.vq_enabled:
            bridge_inputs = {'text': relational}
            if 'image' in seq_outputs:
                bridge_inputs['image'] = seq_outputs['image']
            if 'audio' in seq_outputs:
                bridge_inputs['audio'] = seq_outputs['audio']

            combined, vq_losses, indices_dict = self.bridge(bridge_inputs, timestep=timestep)
            if combined is None:
                combined = relational
            elif combined.shape[-1] == CODEBOOK_DIM:
                # The bridge returned codebook-space vectors: project back to HIDDEN_DIM.
                combined = self.vq_to_trigram_norm(self.vq_to_trigram(combined))
            vq_loss = vq_losses.get('text_vq', torch.zeros((), device=x.device))
            if 'image_vq' in vq_losses:
                vq_loss = vq_loss + vq_losses['image_vq']
            if 'audio_vq' in vq_losses:
                vq_loss = vq_loss + vq_losses['audio_vq']
        else:
            combined = relational
            vq_loss = torch.zeros((), device=x.device)

        # No decay regulariser is computed in this body; a zero is reported
        # when MemGram is enabled.
        memgram_decay_reg = torch.tensor(0.0, device=x.device)

        # --- Optional MemGram mixing ---------------------------------------
        if self.memgram_enabled and self.memgram is not None and self.vq_enabled:
            # Build the all-zero fallback only when the bridge produced no text indices.
            if 'text' in indices_dict:
                vq_indices = indices_dict['text']
            else:
                vq_indices = torch.zeros(combined.shape[0], combined.shape[1], dtype=torch.long, device=x.device)
            combined = self.memgram(
                vq_indices=vq_indices,
                hidden_state=combined,
            )

        all_indices = None
        composite_ids = None
        composite_vq_loss = None
        processed = combined
        moegraph_ponder_loss = torch.tensor(0.0, device=x.device)

        # --- Graph reasoning (MoEGraph + composite head) -------------------
        if self.graph_enabled and self.moegraph is not None and self.vq_enabled and vq_loss is not None:
            # Share the bridge's live codebook table with the graph.
            self.moegraph._codebook_table = self.bridge.vq.table
            self.moegraph._codebook_embed = None

            if 'text' in indices_dict:
                all_indices = indices_dict['text']
            else:
                all_indices = combined.new_zeros(combined.shape[0], combined.shape[1], dtype=torch.long)
            if has_image and 'image' in indices_dict:
                all_indices = torch.cat([all_indices, indices_dict['image']], dim=1)
            if has_audio and 'audio' in indices_dict:
                all_indices = torch.cat([all_indices, indices_dict['audio']], dim=1)

            memgram_cb = None
            if self.memgram_enabled and self.memgram is not None and self.vq_enabled:
                if 'text' in indices_dict:
                    vq_idx = indices_dict['text']
                else:
                    vq_idx = combined.new_zeros(combined.shape[0], combined.shape[1], dtype=torch.long)
                memgram_cb = self.memgram.retrieve_cb(vq_idx)

            attn_out = None
            if self.attention_enabled and self.attention is not None and self.kv_ledger is not None:
                attn_out = self.attention(combined, self.kv_ledger, kq_cache=self.kq_cache)

            processed, moegraph_ponder_loss = self.moegraph(
                combined, all_indices,
                attention_output=attn_out,
                memgram_cb_output=memgram_cb,
                threshold=self.threshold,
            )

            if self.composite_head is not None:
                composite_ids, composite_vq_loss, _ = self.composite_head(processed.mean(dim=1))

            self.moegraph.update_kg_edges(all_indices)

        # --- Output head selection -----------------------------------------
        if targets is not None or output_mode == "text":
            logits = self.byte_head(processed)
        elif output_mode == "video":
            if self.video_head is None:
                raise ValueError("output_mode='video' requested but video output is disabled")
            logits = self.video_head(processed)
        elif output_mode in {"audio", "talker"}:
            if self.talker_head is None:
                raise ValueError("audio/talker output requested but talker output is disabled")
            logits = self.talker_head(processed)
        elif self.training and self.output_router is not None:
            route = self.output_router(processed, training=True)
            # NOTE(review): router outputs are unpacked but not yet used in the
            # loss; training always decodes through the byte head.
            route_weights, route_logits = route
            logits = self.byte_head(processed)
        elif self.output_router is not None:
            route = self.output_router(processed, training=False)
            if isinstance(route, torch.Tensor) and route.numel() > 0:
                # Route id 2 selects video and 3 selects talker; video wins
                # if both appear.
                use_video = (route == 2).any() and self.video_head is not None
                use_talk = (route == 3).any() and self.talker_head is not None
                logits = self.video_head(processed) if use_video else \
                    self.talker_head(processed) if use_talk else \
                    self.byte_head(processed)
            else:
                logits = self.byte_head(processed)
        else:
            logits = self.byte_head(processed)

        # Byte logits cover every modality's positions; keep only the text span.
        T_text = relational.shape[1]
        if logits.dim() == 3 and logits.shape[-1] == VOCAB:
            logits = logits[:, :T_text, :]
        with torch.no_grad():
            self._append_predictions_to_kv(logits.argmax(dim=-1), composite_ids=composite_ids)

        # --- Losses ----------------------------------------------------------
        losses = None
        if targets is not None:
            next_byte_logits = logits[:, :-1, :].contiguous()
            lm_loss = F.cross_entropy(
                next_byte_logits.view(-1, VOCAB),
                targets.contiguous().view(-1),
                ignore_index=SPECIAL_VOCAB["PAD"]
            )
            vq_component = commitment_warmup_weight * vq_loss if self.vq_enabled else None
            losses = LossComponents(
                lm=lm_loss,
                vq_commitment=vq_component,
                graph_l1=None,
                moegraph_ponder=moegraph_ponder_loss,
                memgram_decay_reg=memgram_decay_reg if self.memgram_enabled else None,
                composite_vq=composite_vq_loss if self.composite_head is not None and composite_ids is not None else None,
                weights=loss_weights if loss_weights is not None else LossWeights(),
            )

        return logits, losses, all_indices, None

    @torch.no_grad()
    def _append_predictions_to_kv(self, pred_ids, composite_ids=None):
        """Append predicted token ids (and valid composite ids) to the ledgers.

        For each batch row, every id in ``pred_ids[b]`` is appended to both the
        KV ledger and the KQ cache. Then every non-negative
        ``composite_ids[b, k]`` is appended to the KV ledger only, shifted by
        the bridge's total codebook size so it cannot collide with VQ ids.
        No-op unless both the ledger and the cache exist.
        """
        if self.kv_ledger is None or self.kq_cache is None:
            return
        # One host transfer per tensor, instead of a sync per element.
        pred_rows = pred_ids.tolist()
        comp_rows = None
        composite_offset = 0
        if composite_ids is not None:
            comp_rows = composite_ids.tolist()
            if self.vq_enabled and self.bridge is not None:
                composite_offset = self.bridge.total_codebook_size
        for b, row in enumerate(pred_rows):
            for raw_id in row:
                token_id = int(raw_id)
                self.kv_ledger.append(token_id)
                self.kq_cache.append(token_id)
            if comp_rows is None:
                continue
            for raw_cid in comp_rows[b]:
                cid = int(raw_cid)
                if cid >= 0:
                    self.kv_ledger.append(composite_offset + cid)

    def _ternary_update_memory(self, accum_threshold=8, update_scales=True,
                               loss_components=None, loss_signal=None):
        """Finalize one ternary update after backward.

        If the loss signal is non-finite, all hook state and gradients are
        dropped, a ``RuntimeWarning`` is emitted, and nothing is updated.
        Otherwise the update runs component-wise (when *loss_components* is
        given) or through the regular hook path. Hook attributes and streaming
        flags are cleared afterwards.

        Args:
            accum_threshold: forwarded to ``ternary_step`` /
                ``_apply_accumulated_flips``.
            update_scales: also step the scale exponents ``E``.
            loss_components: a ``LossComponents``; its ``total`` is the signal.
            loss_signal: tensor or Python number, used when
                *loss_components* is ``None``.
        """
        signal = loss_components.total if loss_components is not None else loss_signal
        t_step = self._ternary_t_step(signal)
        # torch.as_tensor lets a plain Python float be checked too.
        if signal is not None and not torch.isfinite(torch.as_tensor(signal).detach()).all():
            warnings.warn("Non-finite loss detected — skipping ternary state update",
                          RuntimeWarning, stacklevel=2)
            self._clear_ternary_hooks()
            self.zero_grad(set_to_none=True)
            return

        if loss_components is not None:
            self._componentwise_ternary_backward(loss_components, t_step, update_scales, accum_threshold)
        else:
            self._apply_regular_ternary_hooks(accum_threshold, update_scales, t_step, loss_signal)
        self._clear_ternary_hooks()
        self._clear_backward_update_flags()

    def prepare_ternary_backward(self, loss_signal=None, update_scales=True):
        """Configure streaming CUDA ternary updates before `loss.backward()`.

        BigInt-scaled dense linear backward accumulates directly into int64
        `corr_accum`, while legacy sparse tables still use int8 `T_accum`.
        Calling this before backward lets the streaming path use the same
        loss-scaled step that `_ternary_update_memory()` will finalize.
        """
        t_step = self._ternary_t_step(loss_signal)
        for module in self.modules():
            if hasattr(module, "T_accum") or hasattr(module, "corr_accum"):
                module._backward_t_accum_step = t_step
                module._backward_update_scales = bool(update_scales)
                module._stream_backward_updates = True

    def _clear_backward_update_flags(self):
        """Remove the per-module streaming flags set by ``prepare_ternary_backward``
        (plus the ``_streamed_*`` markers), wherever present."""
        for module in self.modules():
            for attr in (
                "_backward_t_accum_step",
                "_backward_update_scales",
                "_stream_backward_updates",
                "_streamed_ternary_backward",
                "_streamed_bigint_backward",
            ):
                if hasattr(module, attr):
                    delattr(module, attr)

    @staticmethod
    def _ternary_t_step(loss_signal):
        """Accumulator step size for one update. Currently a constant 1;
        *loss_signal* is ignored."""
        return 1

    def _clear_ternary_hooks(self):
        """Delete every transient ``_hook_*`` attribute and ``_T_accum_fp``.

        This covers both the plain names and the per-component
        ``<name>_<component>`` variants, on every sub-module.
        """
        base_names = (
            "_hook_grad_T_sign", "_hook_grad_2d", "_hook_x_2d", "_hook_T",
            "_hook_sparse_indices", "_hook_sparse_grad_sign", "_hook_sparse_T",
        )
        component_prefixes = (
            "_hook_grad_T_sign_", "_hook_grad_2d_", "_hook_x_2d_", "_hook_T_",
            "_hook_sparse_indices_", "_hook_sparse_grad_sign_", "_hook_sparse_T_",
        )
        for module in self.modules():
            if hasattr(module, "_T_accum_fp"):
                delattr(module, "_T_accum_fp")
            for hook_name in base_names:
                if hasattr(module, hook_name):
                    delattr(module, hook_name)
            # Snapshot the keys: we delete from the instance dict while scanning.
            for hook_name in list(vars(module).keys()):
                if hook_name.startswith(component_prefixes):
                    delattr(module, hook_name)

    def _componentwise_ternary_backward(self, loss_components, t_step, update_scales, accum_threshold):
        """Backpropagate each active loss component separately, then apply updates.

        Only fields that are 0-d, require grad, and have a non-zero weight
        take part. Each component's backward runs with ``_COMPONENT_CONTEXT``
        set to ``(name, weight)``. The graph is retained for all but the last
        component. Captured hooks are consumed right after each backward.
        Finally, scales are stepped and accumulated flips applied on every
        module except large sparse embeddings.
        """
        # Relative import: binds the same kernel module object as the file's
        # top-level imports, so the context we set is the one the kernels read.
        from .kernel.ternary_scale import _COMPONENT_CONTEXT

        self.prepare_ternary_backward(loss_components.total, update_scales=update_scales)
        active = [(n, t, w) for n, t, w in loss_components.active_fields
                  if t is not None and t.dim() == 0 and t.requires_grad and float(w) != 0.0]
        for idx, (name, comp_tensor, weight) in enumerate(active):
            retain = idx < len(active) - 1
            _COMPONENT_CONTEXT.set(name, weight)
            try:
                comp_tensor.backward(retain_graph=retain)
            finally:
                _COMPONENT_CONTEXT.clear()
            self._consume_component_hooks(name, weight, t_step, update_scales, accum_threshold)

        with torch.no_grad():
            for module in self.modules():
                if self._is_large_sparse_embedding(module):
                    continue
                if update_scales:
                    self._step_E_from_accum(module)
                self._apply_accumulated_flips(module, accum_threshold=accum_threshold)

    def _consume_component_hooks(self, name, weight, t_step, update_scales, accum_threshold):
        """Fold the hook tensors captured for loss component *name* into each module.

        Per module, checked in this order:
          * sparse hooks (``_hook_sparse_*_{name}``) are copied onto the plain
            hook names and handed to the module's own ``update_E`` /
            ``ternary_step``; the module is then skipped for the other forms;
          * a dense sign hook (``_hook_grad_T_sign_{name}``) is accumulated
            directly;
          * a grad/activation pair (``_hook_grad_2d_{name}``,
            ``_hook_x_2d_{name}``) is turned into a clamped ``grad.T @ x``
            outer product and accumulated, if finite.
        Consumed per-component attributes are deleted.
        """
        sparse_idx_key = f"_hook_sparse_indices_{name}"
        sparse_grad_key = f"_hook_sparse_grad_sign_{name}"
        sparse_t_key = f"_hook_sparse_T_{name}"
        dense_key = f"_hook_grad_T_sign_{name}"
        dense_t_key = f"_hook_T_{name}"
        grad_key = f"_hook_grad_2d_{name}"
        x_key = f"_hook_x_2d_{name}"
        for module in self.modules():
            if hasattr(module, sparse_idx_key) and hasattr(module, sparse_grad_key):
                setattr(module, "_hook_sparse_indices", getattr(module, sparse_idx_key))
                setattr(module, "_hook_sparse_grad_sign", getattr(module, sparse_grad_key))
                if hasattr(module, sparse_t_key):
                    setattr(module, "_hook_sparse_T", getattr(module, sparse_t_key))
                if update_scales and hasattr(module, "update_E"):
                    module._e_accum_threshold = 8
                    module.update_E()
                if hasattr(module, "T_accum"):
                    # Scale the module's accumulator step by |component weight|.
                    module._t_accum_step = max(1, int(round(abs(float(weight)) * t_step)))
                if hasattr(module, "ternary_step"):
                    module.ternary_step(accum_threshold=accum_threshold)
                for key in (sparse_idx_key, sparse_grad_key, sparse_t_key):
                    if hasattr(module, key):
                        delattr(module, key)
                continue

            if hasattr(module, dense_key):
                self._accumulate_component_grad_continuous(
                    module, getattr(module, dense_key), weight, t_step,
                )
                delattr(module, dense_key)
                if hasattr(module, dense_t_key):
                    delattr(module, dense_t_key)

            if not hasattr(module, grad_key) or not hasattr(module, x_key):
                continue
            comp_grad = getattr(module, grad_key)
            comp_x = getattr(module, x_key)
            if torch.isfinite(comp_grad).all() and torch.isfinite(comp_x).all():
                raw_grad = torch.clamp(comp_grad.transpose(0, 1) @ comp_x, -10.0, 10.0)
                self._accumulate_component_grad_continuous(
                    module, raw_grad, weight, t_step,
                )
            delattr(module, grad_key)
            delattr(module, x_key)

    def _accumulate_component_grad_continuous(self, module, raw_grad, weight, t_step):
        """Component loss accumulation without persistent float optimizer state.

        Only ``sign(raw_grad)`` is used. Its magnitude is
        ``max(1, round(|weight| * t_step))``, negated for negative weights.
        Modules with BigInt ``corr_accum`` receive it via
        ``_accumulate_corr_from_grad_sign``. Legacy modules have
        ``sign * step`` subtracted from int8 ``T_accum``, saturating at ±127.
        Nothing happens when *raw_grad* does not match ``_T_shape``.
        """
        if not hasattr(module, "_T_shape"):
            return
        shape = tuple(int(x) for x in module._T_shape.tolist())
        if tuple(raw_grad.shape) != shape:
            return
        with torch.no_grad():
            step = max(1, int(round(abs(float(weight)) * t_step)))
            if float(weight) < 0:
                step = -step
            if hasattr(module, "corr_accum") and hasattr(module, "_accumulate_corr_from_grad_sign"):
                signed = raw_grad.sign().to(device=module.corr_accum.device, dtype=torch.int8)
                module._accumulate_corr_from_grad_sign(signed, corr_step=step)
                return
            if not hasattr(module, "T_accum") or tuple(module.T_accum.shape) != shape:
                return
            if hasattr(module, "_T_accum_fp"):
                delattr(module, "_T_accum_fp")
            signed = raw_grad.sign().to(device=module.T_accum.device, dtype=torch.int8)
            # T_accum is confined to [-127, 127] and ``signed`` to {-1, 0, 1},
            # so any |step| >= 254 saturates identically. Bounding it keeps the
            # int16 product and difference from overflowing for huge weights.
            step16 = max(-254, min(254, step))
            module.T_accum.copy_(
                torch.clamp(
                    module.T_accum.to(torch.int16) - signed.to(torch.int16) * step16,
                    -127,
                    127,
                ).to(torch.int8)
            )

    def _apply_regular_ternary_hooks(self, accum_threshold, update_scales, t_step, loss_signal):
        """Apply the non-component (single-loss) ternary update to every module.

        Modules whose backward already streamed its updates (and left no
        hooks) only get the E step and flip application. Modules with captured
        hooks go through their own ``ternary_step``. ``t_step`` and
        ``loss_signal`` are accepted but not read here.
        """
        for module in self.modules():
            is_bigint = hasattr(module, "corr_accum") and hasattr(module, "_accumulate_corr_from_grad_sign")
            is_legacy = hasattr(module, "T_accum") or hasattr(module, "E_accum")
            if is_bigint or is_legacy:
                self._prepare_per_group_threshold(module)
            streamed = bool(getattr(module, "_streamed_ternary_backward", False))
            has_hook = (
                hasattr(module, "_hook_grad_T_sign")
                or (hasattr(module, "_hook_grad_2d") and hasattr(module, "_hook_x_2d"))
                or (hasattr(module, "_hook_sparse_indices") and hasattr(module, "_hook_sparse_grad_sign"))
            )
            bigint_streamed = bool(getattr(module, "_streamed_bigint_backward", False))
            if (streamed or bigint_streamed) and not has_hook:
                if streamed and update_scales:
                    self._step_E_from_accum(module)
                if streamed:
                    had_flip = self._apply_accumulated_flips(module, accum_threshold=accum_threshold)
                    self._record_flip_health(module, had_flip)
                if hasattr(module, "per_group_threshold"):
                    del module.per_group_threshold
                continue
            if has_hook:
                if hasattr(module, "_hook_grad_T_sign") and hasattr(module, "_accumulate_corr_from_grad_sign"):
                    module._accumulate_corr_from_grad_sign(module._hook_grad_T_sign)
                    del module._hook_grad_T_sign
                if hasattr(module, "ternary_step"):
                    module.ternary_step(accum_threshold=accum_threshold)
                if hasattr(module, "per_group_threshold"):
                    del module.per_group_threshold

    def _prepare_per_group_threshold(self, module):
        """Attach an int8 per-group flip threshold to *module* (or ``None``).

        Large sparse embeddings, BigInt-only modules, and modules without
        ``E``/``_T_shape`` get ``None``. Otherwise the threshold per scale
        group is ``8 + 0.25 * min(|E|, 32)``, capped at 16 and truncated to
        int8, then flattened.
        """
        if self._is_large_sparse_embedding(module):
            module.per_group_threshold = None
            return
        if hasattr(module, "corr_accum") and not hasattr(module, "T_accum"):
            module.per_group_threshold = None
            return
        if not hasattr(module, "E") or not hasattr(module, "_T_shape"):
            module.per_group_threshold = None
            return
        shape = tuple(int(x) for x in module._T_shape.tolist())
        out_dim, in_dim = shape
        gpr = _ceil_div(in_dim, module.group_size)
        E_view = module.E.view(out_dim, gpr).float()
        threshold_g = 8.0 + 0.25 * E_view.abs().clamp(max=32.0)
        module.per_group_threshold = torch.clamp(threshold_g, max=16.0).to(torch.int8).reshape(-1)

    @staticmethod
    def _is_large_sparse_embedding(module):
        """True for embedding-like modules at or above their ``sparse_threshold``."""
        return (
            hasattr(module, "num_embeddings")
            and hasattr(module, "sparse_threshold")
            and module.num_embeddings >= module.sparse_threshold
        )

    @staticmethod
    def _step_E_from_accum(module):
        """Move each int8 scale exponent ``E`` by ±1 once its ``E_accum`` reaches ±threshold.

        The threshold is ``module._e_accum_threshold`` (default 8). A moved
        entry has ``threshold`` subtracted from (or added to) its accumulator.
        ``E`` saturates at the int8 range. Skipped for BigInt (``corr_accum``)
        modules.
        """
        if hasattr(module, "corr_accum"):
            return
        if not hasattr(module, "E") or not hasattr(module, "E_accum"):
            return
        threshold = int(getattr(module, "_e_accum_threshold", 8))
        accum = module.E_accum.to(torch.int16)
        step = torch.where(
            accum >= threshold,
            torch.ones_like(accum, dtype=torch.int16),
            torch.where(accum <= -threshold, torch.full_like(accum, -1, dtype=torch.int16), torch.zeros_like(accum, dtype=torch.int16)),
        )
        if step.any():
            module.E = torch.clamp(module.E.to(torch.int16) + step, -128, 127).to(torch.int8)
            module.E_accum = (accum - step * threshold).to(torch.int8)

    @staticmethod
    def _apply_accumulated_flips(module, accum_threshold=3):
        """Packed-byte carry: move each trit by ±1 where ``|T_accum| > 1``.

        Each packed byte holds five base-3 digits; digit ``pos`` has weight
        ``3**pos``. Weight ``i`` of the row-major flattened ``(out, in)``
        matrix lives in byte ``i // 5``. Its digit is therefore ``i % 5``,
        consistent with that byte index — NOTE(review): confirm this matches
        ``pack_ternary``.

        A move that would push a digit past its range (below 0 or above 2)
        is dropped, so it cannot carry into or borrow from a neighbouring
        trit. All accumulators that crossed the threshold are reset to zero,
        whether or not their trit could move.

        Args:
            module: a module with ``T_accum``, ``T_packed`` and ``_T_shape``.
            accum_threshold: accepted for call-site compatibility; the carry
                threshold is fixed at ``|T_accum| > 1``.

        Returns:
            True if any accumulator crossed the threshold, else False.
        """
        if not hasattr(module, "T_accum") or not hasattr(module, "T_packed") or not hasattr(module, "_T_shape"):
            return False
        shape = tuple(int(x) for x in module._T_shape.tolist())
        if tuple(module.T_accum.shape) != shape:
            return False
        carry_up = module.T_accum > 1
        carry_down = module.T_accum < -1
        if not carry_up.any() and not carry_down.any():
            return False
        dev = module.T_packed.device
        out_dim, in_dim = shape
        pows = torch.tensor([1, 3, 9, 27, 81], device=dev, dtype=torch.int16)
        pk = module.T_packed.to(torch.int16).reshape(-1).clone()
        flat_up = carry_up.reshape(-1).to(dev)
        flat_dn = carry_down.reshape(-1).to(dev)
        n_weights = out_dim * in_dim
        # Iterate by digit position within the byte. All weights sharing a
        # position map to distinct bytes, so the scatter below has no
        # duplicate indices.
        for pos in range(5):
            lin_idx = torch.arange(pos, n_weights, 5, device=dev)
            if lin_idx.numel() == 0:
                continue
            is_up = flat_up[lin_idx]
            is_dn = flat_dn[lin_idx]
            if not is_up.any() and not is_dn.any():
                continue
            byte_idx = lin_idx // 5
            pv = pk[byte_idx]
            digit = torch.div(pv, pows[pos], rounding_mode="floor") % 3
            p_up = (pv + pows[pos]).clamp(0, 242)
            p_dn = (pv - pows[pos]).clamp(0, 242)
            pk[byte_idx] = torch.where(
                is_up & (digit < 2),
                p_up,
                torch.where(is_dn & (digit > 0), p_dn, pv),
            )
        module.T_packed = pk.to(torch.uint8).reshape(module.T_packed.shape)

        module.T_accum.masked_fill_(carry_up | carry_down, 0)
        return True

    @staticmethod
    def _record_flip_health(module, had_flip):
        """Track ``_steps_since_flip`` (reset on a flip, else incremented) and
        clear ``_had_flip``; only for modules with ``T_accum``."""
        if not hasattr(module, "T_accum"):
            return
        steps_since = getattr(module, "_steps_since_flip", 0)
        module._steps_since_flip = 0 if had_flip else steps_since + 1
        module._had_flip = False

    def generate(self, idx, max_new_token, temperature=1.0, images=None, audio=None,
                 conversation_id=None, top_k=None, min_new_tokens=0, return_metadata=False):
        """Autoregressively sample ``max_new_token`` bytes after the prompt *idx*.

        If the KV ledger is empty, it is first seeded with the flattened
        prompt (the KQ cache too, when present). Each step runs ``forward`` on
        the last ``CTX`` tokens in text mode and samples from the last
        position. Runs under ``torch.no_grad()``.

        Args:
            idx: prompt token ids, indexed as ``[:, -CTX:]`` (2-D).
            max_new_token: number of tokens to append.
            temperature: softmax temperature; ``<= 0`` selects greedy argmax.
            images, audio: forwarded to ``forward`` on every step.
            conversation_id, min_new_tokens: accepted but not used here.
            top_k: if positive, sample only among the top-k logits.
            return_metadata: return a dict instead of the bare tensor.

        Returns:
            The extended token tensor or, with *return_metadata*, a dict with
            ``tokens``, ``n_generated`` (tokens appended to the prompt) and
            ``temperature``.
        """
        prompt_len = idx.shape[1]
        with torch.no_grad():
            if self.kv_ledger is not None and self.kv_ledger.size == 0:
                for raw_id in idx.reshape(-1).tolist():
                    token_id = int(raw_id)
                    self.kv_ledger.append(token_id)
                    if self.kq_cache is not None:
                        self.kq_cache.append(token_id)
            for i in range(max_new_token):
                idx_cond = idx[:, -CTX:]
                logits, _, _, _ = self(idx_cond, images=images, audio=audio, timestep=i, output_mode="text")
                last_logits = logits[:, -1, :]
                if temperature <= 0:
                    # Greedy decoding instead of dividing by a zero/negative temperature.
                    idx_next = last_logits.argmax(dim=-1, keepdim=True)
                else:
                    last_logits = last_logits / temperature
                    if top_k is not None and top_k > 0:
                        v, _ = torch.topk(last_logits, min(top_k, last_logits.size(-1)))
                        kth = v[:, -1].unsqueeze(-1).expand_as(last_logits)
                        last_logits = last_logits.where(last_logits >= kth, float('-inf'))
                    probs = F.softmax(last_logits, dim=-1)
                    idx_next = torch.multinomial(probs, num_samples=1)
                idx = torch.cat([idx, idx_next], dim=1)

        if return_metadata:
            return {
                "tokens": idx,
                "n_generated": idx.shape[1] - prompt_len,
                "temperature": temperature,
            }
        return idx
|
|