ARBS / arbitor /main.py
CLIWorks's picture
Upload folder using huggingface_hub
d8bc908 verified
"""ARB — Any Relational Bit. Core model assembly."""
import warnings
import torch
import torch.nn as nn
import torch.nn.functional as F
from math import ceil as _ceil
_ceil_div = lambda a, b: _ceil(a / b) if b > 0 else 0
from .config import VOCAB, HIDDEN_DIM, SPECIAL_VOCAB, CTX, THRESHOLD, CODEBOOK_DIM, CODEBOOK_SIZE, KV_LEDGER_SIZE, KQ_CACHE_SIZE, MEMGRAM_STRUCT_PRIMES, MEMGRAM_CONV_PRIMES, MEMGRAM_EMBED_DIM, MEMGRAM_KEY_DIM, KGVQ_CODEBOOK_SIZE, KGVQ_CODEBOOK_DIM, K_MAX_COMPOSITES, MG_TOP_K
from .kernel.ternary_scale import TScaleType, TernaryScaleTensor, TernaryRMSNorm, _HAS_TRITON
try:
from .kernel.ternary_scale import _triton_apply_accumulated_flips
except ImportError:
_triton_apply_accumulated_flips = None
from .converters.convert_to_ternary8 import pack_ternary
try:
from .kernel.ternary_scale import _TritonTernaryEmbedFn
except ImportError:
_TritonTernaryEmbedFn = None
from .sequencers import ByteEmbedding, MultimodalSequencer
from .vq import SharedVQ
from .components import (
ByteHead, OutputRouter,
MemGram, LossComponents, LossWeights,
CompositeProposalHead, MoEGraph,
)
from .decoders import VideoHead, TalkerHead
from .components import _BOUNDARY_TOKEN_MAP as _BOUNDARY_MAP
from .attention import KVLedger, KQCache, ContextAttentionScheduler
from .kernel.flash_vq import FlashVQCodebook
def _extract_boundary_from_input(x):
if x.dim() != 2:
return None
first_token = x[0, 0].item()
if first_token in _BOUNDARY_MAP:
return first_token
for tok in x[0].tolist():
if tok in _BOUNDARY_MAP:
return tok
return None
class ARBModel(nn.Module):
    """Core ARB model assembly.

    Pipeline (every stage after the sequencer is optional and controlled by
    constructor flags): byte embedding -> multimodal sequencer -> shared VQ
    bridge -> MemGram injection -> MoEGraph (+ composite motif head) ->
    output head (byte / video / talker).

    The class also hosts the integer-only "ternary" update machinery that
    consumes hook tensors stashed on sub-modules during backward and turns
    them into accumulator updates and packed-trit flips.
    """

    def __init__(self, tscale_type=TScaleType.T32, threshold=THRESHOLD,
                 max_graph_hops=4, max_moe_iters=4, halt_threshold=0.99,
                 enable_image=False, enable_audio=False, enable_vq=True, enable_graph=True,
                 enable_memory_modules=False, enable_moe=True,
                 shared_vq_size=None, kgvq_codebook_size=None,
                 enable_attention=True, enable_output_router=True,
                 enable_video_output=True, enable_talker_output=True):
        """Build all sub-modules.

        ``max_graph_hops`` and ``enable_moe`` are accepted for interface
        compatibility but are not read here. The graph stage (MoEGraph and
        the composite head) is only built when both ``enable_graph`` and
        ``enable_vq`` are true, because it indexes the VQ codebook.
        """
        super().__init__()
        self.image_enabled = enable_image
        self.audio_enabled = enable_audio
        self.embedding = ByteEmbedding(tscale_type=tscale_type)
        self.multimodal_sequencer = MultimodalSequencer(
            tscale_type=tscale_type,
            enable_text=True, enable_image=enable_image, enable_audio=enable_audio,
        )
        self.text_sequencer = self.multimodal_sequencer.text
        self.image_sequencer = self.multimodal_sequencer.image
        self.audio_sequencer = self.multimodal_sequencer.audio

        # Shared VQ bridge plus the projection from codebook space to HIDDEN_DIM.
        self.vq_enabled = enable_vq
        self.bridge = SharedVQ(
            codebook_size=shared_vq_size,
            tscale_type=tscale_type, enable_image=enable_image, enable_audio=enable_audio,
        ) if enable_vq else None
        self.vq_to_trigram = TernaryScaleTensor(CODEBOOK_DIM, HIDDEN_DIM, tscale_type=tscale_type) if enable_vq else None
        self.vq_to_trigram_norm = TernaryRMSNorm(HIDDEN_DIM, tscale_type=tscale_type) if enable_vq else None

        self.graph_enabled = enable_graph and enable_vq
        graph_vocab_size = self.bridge.total_codebook_size if self.graph_enabled else None
        self.threshold = threshold
        self.moegraph = MoEGraph(
            trigram_dim=HIDDEN_DIM, codebook_size=graph_vocab_size or CODEBOOK_SIZE,
            max_iters=max_moe_iters, halt_threshold=halt_threshold,
            top_k=MG_TOP_K,
        ) if self.graph_enabled else None
        self.byte_head = ByteHead(tscale_type=tscale_type)
        # Composite motif generation (Phase 17)
        self.composite_head = CompositeProposalHead(
            dim=HIDDEN_DIM, codebook_dim=KGVQ_CODEBOOK_DIM,
            k_max=K_MAX_COMPOSITES, codebook_size=kgvq_codebook_size or KGVQ_CODEBOOK_SIZE,
            tscale_type=tscale_type,
        ) if self.graph_enabled else None
        self.output_router = OutputRouter(tscale_type=tscale_type, depth=3) if enable_output_router else None
        self.video_head = VideoHead(tscale_type=tscale_type) if enable_video_output else None
        self.talker_head = TalkerHead(tscale_type=tscale_type) if enable_talker_output else None
        self.memgram = MemGram(
            struct_primes=MEMGRAM_STRUCT_PRIMES,
            conv_primes=MEMGRAM_CONV_PRIMES,
            embed_dim=MEMGRAM_EMBED_DIM, key_dim=MEMGRAM_KEY_DIM, hidden_dim=HIDDEN_DIM,
        ) if enable_memory_modules else None
        self.memgram_enabled = self.memgram is not None
        # KV Ledger + Attention (Phase 16 — replaces LSTM)
        self.kv_ledger = KVLedger(max_size=KV_LEDGER_SIZE) if enable_attention else None
        self.kq_cache = KQCache(max_size=KQ_CACHE_SIZE) if enable_attention else None
        self.attention = ContextAttentionScheduler(dim=HIDDEN_DIM) if enable_attention else None
        self.attention_enabled = bool(enable_attention)

    def forward(self, x, targets=None, commitment_warmup_weight=1.0,
                act_warmup_mode=False, ponder_lambda=0.01, images=None,
                audio=None, timestep=0, loss_weights=None, output_mode=None):
        """Run the model on a batch of byte ids.

        Args:
            x: byte-id tensor fed to ``self.embedding``; ``x.device`` is used
                for zero-loss placeholders. Presumably (batch, seq) — TODO
                confirm against ByteEmbedding.
            targets: optional next-byte targets. They are flattened and
                compared with ``logits[:, :-1]``, so they must provide one id
                per position of that slice (presumably ``x[:, 1:]`` — confirm
                with callers).
            commitment_warmup_weight: multiplier applied to the VQ loss.
            act_warmup_mode, ponder_lambda: accepted but not read here.
            images, audio: optional extra modalities. Each requires the
                matching ``enable_*`` flag at construction time.
            timestep: forwarded to the VQ bridge.
            loss_weights: ``LossWeights`` for the returned ``LossComponents``
                (``LossWeights()`` when ``None``).
            output_mode: ``"text"``, ``"video"``, ``"audio"``/``"talker"``,
                or ``None`` to let the output router decide.

        Returns:
            ``(logits, losses, all_indices, None)``. ``losses`` is a
            ``LossComponents`` when ``targets`` is given, else ``None``.
            ``all_indices`` is the concatenated VQ index tensor fed to the
            graph, or ``None`` when the graph stage did not run.

        Raises:
            ValueError: a modality or output head was requested but disabled.

        Side effects:
            Argmax predictions are appended to the KV ledger / KQ cache on
            every call. The graph's KG co-occurrence state is updated
            whenever the graph stage runs.
        """
        has_image = images is not None
        has_audio = audio is not None
        if has_image and (not self.image_enabled or self.image_sequencer is None):
            raise ValueError("images provided but model has enable_image=False")
        if has_audio and (not self.audio_enabled or self.audio_sequencer is None):
            raise ValueError("audio provided but model has enable_audio=False")

        # 1. Embed bytes and run the per-modality sequencers.
        embedded = self.embedding(x)
        seq_inputs = {'text': embedded}
        if has_image:
            seq_inputs['image'] = images
        if has_audio:
            seq_inputs['audio'] = audio
        seq_outputs = self.multimodal_sequencer(seq_inputs)
        relational = seq_outputs['text']

        # 2. Shared VQ bridge (optional).
        indices_dict = {}
        if self.vq_enabled:
            bridge_inputs = {'text': relational}
            if 'image' in seq_outputs:
                bridge_inputs['image'] = seq_outputs['image']
            if 'audio' in seq_outputs:
                bridge_inputs['audio'] = seq_outputs['audio']
            combined, vq_losses, indices_dict = self.bridge(bridge_inputs, timestep=timestep)
            if combined is None:
                combined = relational
            elif combined.shape[-1] == CODEBOOK_DIM:
                # Bridge returned codebook-space vectors: project back to HIDDEN_DIM.
                combined = self.vq_to_trigram_norm(self.vq_to_trigram(combined))
            vq_loss = vq_losses.get('text_vq', torch.zeros((), device=x.device))
            if 'image_vq' in vq_losses:
                vq_loss = vq_loss + vq_losses['image_vq']
            if 'audio_vq' in vq_losses:
                vq_loss = vq_loss + vq_losses['audio_vq']
        else:
            combined = relational
            vq_loss = torch.zeros((), device=x.device)

        # 3. MemGram injection (after VQ, before Graph — D92).
        memgram_decay_reg = torch.tensor(0.0, device=x.device)
        if self.memgram_enabled and self.memgram is not None and self.vq_enabled:
            vq_indices = indices_dict.get('text', torch.zeros(combined.shape[0], combined.shape[1], dtype=torch.long, device=x.device))
            combined = self.memgram(
                vq_indices=vq_indices,
                hidden_state=combined,
            )

        # 4. MoEGraph + composite motif head (optional).
        all_indices = None
        composite_ids = None
        composite_vq_loss = None
        processed = combined
        moegraph_ponder_loss = torch.tensor(0.0, device=x.device)
        if self.graph_enabled and self.moegraph is not None and self.vq_enabled and vq_loss is not None:
            self.moegraph._codebook_table = self.bridge.vq.table
            self.moegraph._codebook_embed = None
            all_indices = indices_dict.get('text', combined.new_zeros(combined.shape[0], combined.shape[1], dtype=torch.long))
            if has_image and 'image' in indices_dict:
                all_indices = torch.cat([all_indices, indices_dict['image']], dim=1)
            if has_audio and 'audio' in indices_dict:
                all_indices = torch.cat([all_indices, indices_dict['audio']], dim=1)
            # MemGram retrieval for MoEGraph injection
            memgram_cb = None
            if self.memgram_enabled and self.memgram is not None and self.vq_enabled:
                vq_idx = indices_dict.get('text', combined.new_zeros(combined.shape[0], combined.shape[1], dtype=torch.long))
                memgram_cb = self.memgram.retrieve_cb(vq_idx)
            # Attention output for KV conditioning
            attn_out = None
            if self.attention_enabled and self.attention is not None and self.kv_ledger is not None:
                attn_out = self.attention(combined, self.kv_ledger, kq_cache=self.kq_cache)
            # MoEGraph forward (unified ACT loop)
            processed, moegraph_ponder_loss = self.moegraph(
                combined, all_indices,
                attention_output=attn_out,
                memgram_cb_output=memgram_cb,
                threshold=self.threshold,
            )
            # Composite motif generation (Phase 17)
            if self.composite_head is not None:
                composite_ids, composite_vq_loss, _ = self.composite_head(processed.mean(dim=1))
            # Update bounded int-only KG co-occurrence state.
            self.moegraph.update_kg_edges(all_indices)

        # 5. Output head selection. Targets or an explicit "text" mode force
        # the byte head; otherwise an explicit mode, then the router, decides.
        if targets is not None or output_mode == "text":
            logits = self.byte_head(processed)
        elif output_mode == "video":
            if self.video_head is None:
                raise ValueError("output_mode='video' requested but video output is disabled")
            logits = self.video_head(processed)
        elif output_mode in {"audio", "talker"}:
            if self.talker_head is None:
                raise ValueError("audio/talker output requested but talker output is disabled")
            logits = self.talker_head(processed)
        elif self.training and self.output_router is not None:
            route = self.output_router(processed, training=True)
            # Unpacking is kept: it asserts the router's training-mode contract.
            _route_weights, _route_logits = route
            logits = self.byte_head(processed)
        elif self.output_router is not None:
            route = self.output_router(processed, training=False)
            if isinstance(route, torch.Tensor) and route.numel() > 0:
                use_video = (route == 2).any() and self.video_head is not None
                use_talk = (route == 3).any() and self.talker_head is not None
                if use_video:
                    logits = self.video_head(processed)
                elif use_talk:
                    logits = self.talker_head(processed)
                else:
                    logits = self.byte_head(processed)
            else:
                logits = self.byte_head(processed)
        else:
            logits = self.byte_head(processed)

        # Byte logits cover text + any appended modality positions; keep text only.
        T_text = relational.shape[1]
        if logits.dim() == 3 and logits.shape[-1] == VOCAB:
            logits = logits[:, :T_text, :]
        with torch.no_grad():
            self._append_predictions_to_kv(logits.argmax(dim=-1), composite_ids=composite_ids)

        losses = None
        if targets is not None:
            next_byte_logits = logits[:, :-1, :].contiguous()
            lm_loss = F.cross_entropy(
                next_byte_logits.view(-1, VOCAB),
                targets.contiguous().view(-1),
                ignore_index=SPECIAL_VOCAB["PAD"]
            )
            vq_component = commitment_warmup_weight * vq_loss if self.vq_enabled else None
            losses = LossComponents(
                lm=lm_loss,
                vq_commitment=vq_component,
                graph_l1=None,
                moegraph_ponder=moegraph_ponder_loss,
                memgram_decay_reg=memgram_decay_reg if self.memgram_enabled else None,
                composite_vq=composite_vq_loss if self.composite_head is not None and composite_ids is not None else None,
                weights=loss_weights if loss_weights is not None else LossWeights(),
            )
        return logits, losses, all_indices, None

    @torch.no_grad()
    def _append_predictions_to_kv(self, pred_ids, composite_ids=None):
        """Append predicted ids (and non-negative composite ids) to the KV state.

        For each batch row, every id in ``pred_ids[b]`` goes to both the KV
        ledger and the KQ cache. Each non-negative ``composite_ids[b, k]`` is
        then appended to the ledger only, shifted by the shared VQ codebook
        size when VQ is enabled. Nothing happens if either store is missing.
        """
        if self.kv_ledger is None or self.kq_cache is None:
            return
        # One host transfer per tensor instead of a device sync per element.
        pred_rows = pred_ids.tolist()
        comp_rows = composite_ids.tolist() if composite_ids is not None else None
        composite_offset = 0
        if comp_rows is not None and self.vq_enabled and self.bridge is not None:
            composite_offset = self.bridge.total_codebook_size
        for b, row in enumerate(pred_rows):
            for token in row:
                token_id = int(token)
                self.kv_ledger.append(token_id)
                self.kq_cache.append(token_id)
            if comp_rows is None:
                continue
            for cid in comp_rows[b]:
                cid = int(cid)
                if cid >= 0:
                    self.kv_ledger.append(composite_offset + cid)

    def _ternary_update_memory(self, accum_threshold=8, update_scales=True,
                               loss_components=None, loss_signal=None):
        """Finalize one integer/ternary update step after backward.

        With ``loss_components`` the component-wise backward path is used and
        ``loss_components.total`` is the loss signal. Otherwise ``loss_signal``
        is used and the hooks left by a regular backward are applied. A
        non-finite signal skips the update, clears all hook/flag state and
        zeroes gradients.
        """
        signal = loss_components.total if loss_components is not None else loss_signal
        t_step = self._ternary_t_step(signal)
        if signal is not None and not torch.isfinite(signal.detach()).all():
            warnings.warn("Non-finite loss detected — skipping ternary state update",
                          RuntimeWarning, stacklevel=2)
            self._clear_ternary_hooks()
            # Also drop the flags set by prepare_ternary_backward(). Otherwise
            # flags such as `_streamed_ternary_backward` (read by
            # _apply_regular_ternary_hooks) survive into the next step.
            self._clear_backward_update_flags()
            self.zero_grad(set_to_none=True)
            return
        if loss_components is not None:
            self._componentwise_ternary_backward(loss_components, t_step, update_scales, accum_threshold)
        else:
            self._apply_regular_ternary_hooks(accum_threshold, update_scales, t_step, loss_signal)
        self._clear_ternary_hooks()
        self._clear_backward_update_flags()

    def prepare_ternary_backward(self, loss_signal=None, update_scales=True):
        """Configure streaming CUDA ternary updates before `loss.backward()`.
        BigInt-scaled dense linear backward accumulates directly into int64
        `corr_accum`, while legacy sparse tables still use int8 `T_accum`.
        Calling this before backward lets the streaming path use the same
        loss-scaled step that `_ternary_update_memory()` will finalize.
        """
        t_step = self._ternary_t_step(loss_signal)
        for module in self.modules():
            if hasattr(module, "T_accum") or hasattr(module, "corr_accum"):
                module._backward_t_accum_step = t_step
                module._backward_update_scales = bool(update_scales)
                module._stream_backward_updates = True

    def _clear_backward_update_flags(self):
        """Remove the per-module streaming-backward flags from every module."""
        flag_names = (
            "_backward_t_accum_step",
            "_backward_update_scales",
            "_stream_backward_updates",
            "_streamed_ternary_backward",
            "_streamed_bigint_backward",
        )
        for module in self.modules():
            for attr in flag_names:
                if hasattr(module, attr):
                    delattr(module, attr)

    @staticmethod
    def _ternary_t_step(loss_signal):
        """Return the accumulator step size; currently always 1 (signal ignored)."""
        return 1

    def _clear_ternary_hooks(self):
        """Delete every backward hook tensor (plain and per-component) from all modules.

        Also removes any leftover ``_T_accum_fp`` attribute.
        """
        base_names = (
            "_hook_grad_T_sign", "_hook_grad_2d", "_hook_x_2d", "_hook_T",
            "_hook_sparse_indices", "_hook_sparse_grad_sign", "_hook_sparse_T",
        )
        # Per-component hooks are stored as f"{base}_{component_name}".
        component_prefixes = tuple(f"{base}_" for base in base_names)
        for module in self.modules():
            if hasattr(module, "_T_accum_fp"):
                delattr(module, "_T_accum_fp")
            for hook_name in base_names:
                if hasattr(module, hook_name):
                    delattr(module, hook_name)
            for hook_name in list(vars(module).keys()):
                if hook_name.startswith(component_prefixes):
                    delattr(module, hook_name)

    def _componentwise_ternary_backward(self, loss_components, t_step, update_scales, accum_threshold):
        """Back-propagate each active loss component separately, then step.

        Only scalar components that require grad and have a non-zero weight
        are used. Each backward runs with ``_COMPONENT_CONTEXT`` set to
        ``(name, weight)``, and that component's hooks are consumed right
        after. The graph is retained for all but the last component. Finally,
        scales and packed trits are stepped for every module except large
        sparse embeddings.
        """
        # Relative import, matching the module-level imports, so the context
        # object is the one from this package's own kernel module whatever
        # name the package was imported under.
        from .kernel.ternary_scale import _COMPONENT_CONTEXT
        self.prepare_ternary_backward(loss_components.total, update_scales=update_scales)
        active = [(n, t, w) for n, t, w in loss_components.active_fields
                  if t is not None and t.dim() == 0 and t.requires_grad and float(w) != 0.0]
        for idx, (name, comp_tensor, weight) in enumerate(active):
            retain = idx < len(active) - 1
            _COMPONENT_CONTEXT.set(name, weight)
            try:
                comp_tensor.backward(retain_graph=retain)
            finally:
                _COMPONENT_CONTEXT.clear()
            self._consume_component_hooks(name, weight, t_step, update_scales, accum_threshold)
        with torch.no_grad():
            for module in self.modules():
                if self._is_large_sparse_embedding(module):
                    continue
                if update_scales:
                    self._step_E_from_accum(module)
                self._apply_accumulated_flips(module, accum_threshold=accum_threshold)

    def _consume_component_hooks(self, name, weight, t_step, update_scales, accum_threshold):
        """Apply, then delete, the hook tensors one loss component left on modules.

        Hooks are keyed by the ``_{name}`` suffix. Three kinds are handled:
        sparse (indices + grad sign, optional T), routed through the module's
        own ``update_E`` / ``ternary_step``; dense grad sign; and dense
        ``(grad_2d, x_2d)`` pairs, reduced to ``clamp(grad.T @ x, -10, 10)``
        when both are finite.
        """
        # Key names depend only on the component, not on the module.
        sparse_idx_key = f"_hook_sparse_indices_{name}"
        sparse_grad_key = f"_hook_sparse_grad_sign_{name}"
        sparse_t_key = f"_hook_sparse_T_{name}"
        dense_key = f"_hook_grad_T_sign_{name}"
        dense_t_key = f"_hook_T_{name}"
        grad_key = f"_hook_grad_2d_{name}"
        x_key = f"_hook_x_2d_{name}"
        for module in self.modules():
            if hasattr(module, sparse_idx_key) and hasattr(module, sparse_grad_key):
                # Expose this component's sparse hooks under the plain names
                # that the module's own update methods read.
                setattr(module, "_hook_sparse_indices", getattr(module, sparse_idx_key))
                setattr(module, "_hook_sparse_grad_sign", getattr(module, sparse_grad_key))
                if hasattr(module, sparse_t_key):
                    setattr(module, "_hook_sparse_T", getattr(module, sparse_t_key))
                if update_scales and hasattr(module, "update_E"):
                    module._e_accum_threshold = 8
                    module.update_E()
                if hasattr(module, "T_accum"):
                    module._t_accum_step = max(1, int(round(abs(float(weight)) * t_step)))
                if hasattr(module, "ternary_step"):
                    module.ternary_step(accum_threshold=accum_threshold)
                for key in (sparse_idx_key, sparse_grad_key, sparse_t_key):
                    if hasattr(module, key):
                        delattr(module, key)
                continue
            if hasattr(module, dense_key):
                self._accumulate_component_grad_continuous(
                    module, getattr(module, dense_key), weight, t_step,
                )
                delattr(module, dense_key)
                if hasattr(module, dense_t_key):
                    delattr(module, dense_t_key)
            if not hasattr(module, grad_key) or not hasattr(module, x_key):
                continue
            comp_grad = getattr(module, grad_key)
            comp_x = getattr(module, x_key)
            if torch.isfinite(comp_grad).all() and torch.isfinite(comp_x).all():
                raw_grad = torch.clamp(comp_grad.transpose(0, 1) @ comp_x, -10.0, 10.0)
                self._accumulate_component_grad_continuous(
                    module, raw_grad, weight, t_step,
                )
            delattr(module, grad_key)
            delattr(module, x_key)

    def _accumulate_component_grad_continuous(self, module, raw_grad, weight, t_step):
        """Component loss accumulation without persistent float optimizer state.

        The step is ``max(1, round(|weight| * t_step))``, negated for a
        negative weight. BigInt modules receive ``sign(raw_grad)`` via
        ``_accumulate_corr_from_grad_sign``. Legacy modules get
        ``T_accum -= sign(raw_grad) * step``, saturated to [-127, 127].
        Shape mismatches with ``_T_shape`` are ignored.
        """
        if not hasattr(module, "_T_shape"):
            return
        shape = tuple(int(v) for v in module._T_shape.tolist())
        if tuple(raw_grad.shape) != shape:
            return
        with torch.no_grad():
            step = max(1, int(round(abs(float(weight)) * t_step)))
            if float(weight) < 0:
                step = -step
            if hasattr(module, "corr_accum") and hasattr(module, "_accumulate_corr_from_grad_sign"):
                signed = raw_grad.sign().to(device=module.corr_accum.device, dtype=torch.int8)
                module._accumulate_corr_from_grad_sign(signed, corr_step=step)
                return
            if not hasattr(module, "T_accum") or tuple(module.T_accum.shape) != shape:
                return
            if hasattr(module, "_T_accum_fp"):
                delattr(module, "_T_accum_fp")
            signed = raw_grad.sign().to(device=module.T_accum.device, dtype=torch.int8)
            # Widen to int16 so the subtraction cannot wrap before clamping.
            module.T_accum.copy_(
                torch.clamp(
                    module.T_accum.to(torch.int16) - signed.to(torch.int16) * step,
                    -127,
                    127,
                ).to(torch.int8)
            )

    def _apply_regular_ternary_hooks(self, accum_threshold, update_scales, t_step, loss_signal):
        """Apply hooks left by a regular (non component-wise) backward.

        Modules that streamed their update during backward and left no hooks
        only get the E step and packed-trit flips (for the ternary stream).
        Modules with hooks get their grad sign folded into ``corr_accum``
        (BigInt) and then run ``ternary_step``. ``t_step`` and ``loss_signal``
        are accepted but not read here.
        """
        for module in self.modules():
            is_bigint = hasattr(module, "corr_accum") and hasattr(module, "_accumulate_corr_from_grad_sign")
            is_legacy = hasattr(module, "T_accum") or hasattr(module, "E_accum")
            if is_bigint or is_legacy:
                self._prepare_per_group_threshold(module)
            streamed = bool(getattr(module, "_streamed_ternary_backward", False))
            has_hook = (
                hasattr(module, "_hook_grad_T_sign")
                or (hasattr(module, "_hook_grad_2d") and hasattr(module, "_hook_x_2d"))
                or (hasattr(module, "_hook_sparse_indices") and hasattr(module, "_hook_sparse_grad_sign"))
            )
            bigint_streamed = bool(getattr(module, "_streamed_bigint_backward", False))
            if (streamed or bigint_streamed) and not has_hook:
                if streamed and update_scales:
                    self._step_E_from_accum(module)
                if streamed:
                    had_flip = self._apply_accumulated_flips(module, accum_threshold=accum_threshold)
                    self._record_flip_health(module, had_flip)
                if hasattr(module, "per_group_threshold"):
                    del module.per_group_threshold
                continue
            if has_hook:
                if hasattr(module, "_hook_grad_T_sign") and hasattr(module, "_accumulate_corr_from_grad_sign"):
                    module._accumulate_corr_from_grad_sign(module._hook_grad_T_sign)
                    del module._hook_grad_T_sign
                if hasattr(module, "ternary_step"):
                    module.ternary_step(accum_threshold=accum_threshold)
            if hasattr(module, "per_group_threshold"):
                del module.per_group_threshold

    def _prepare_per_group_threshold(self, module):
        """Attach ``module.per_group_threshold`` (flat int8, one per scale group) or ``None``.

        ``None`` is used for large sparse embeddings, BigInt-only modules
        (``corr_accum`` without ``T_accum``) and modules lacking ``E`` or
        ``_T_shape``. Otherwise each group's threshold is
        ``8 + 0.25 * min(|E|, 32)``, capped at 16 and truncated to int8.
        """
        if self._is_large_sparse_embedding(module):
            module.per_group_threshold = None
            return
        if hasattr(module, "corr_accum") and not hasattr(module, "T_accum"):
            module.per_group_threshold = None
            return
        if not hasattr(module, "E") or not hasattr(module, "_T_shape"):
            module.per_group_threshold = None
            return
        shape = tuple(int(v) for v in module._T_shape.tolist())
        out_dim, in_dim = shape
        groups_per_row = _ceil_div(in_dim, module.group_size)
        E_view = module.E.view(out_dim, groups_per_row).float()
        threshold_g = 8.0 + 0.25 * E_view.abs().clamp(max=32.0)
        module.per_group_threshold = torch.clamp(threshold_g, max=16.0).to(torch.int8).reshape(-1)

    @staticmethod
    def _is_large_sparse_embedding(module):
        """True for embedding-like modules with ``num_embeddings >= sparse_threshold``."""
        return (
            hasattr(module, "num_embeddings")
            and hasattr(module, "sparse_threshold")
            and module.num_embeddings >= module.sparse_threshold
        )

    @staticmethod
    def _step_E_from_accum(module):
        """Move int8 scale exponents ``E`` by ±1 where ``E_accum`` reached ±threshold.

        The threshold is ``module._e_accum_threshold`` (default 8). Each step
        consumes ``threshold`` from the accumulator. BigInt modules (with
        ``corr_accum``) are skipped.
        """
        if hasattr(module, "corr_accum"):
            return  # BigInt modules don't use E_accum threshold flips
        if not hasattr(module, "E") or not hasattr(module, "E_accum"):
            return
        threshold = int(getattr(module, "_e_accum_threshold", 8))
        accum = module.E_accum.to(torch.int16)
        step = torch.where(
            accum >= threshold,
            torch.ones_like(accum, dtype=torch.int16),
            torch.where(accum <= -threshold, torch.full_like(accum, -1, dtype=torch.int16), torch.zeros_like(accum, dtype=torch.int16)),
        )
        if step.any():
            module.E = torch.clamp(module.E.to(torch.int16) + step, -128, 127).to(torch.int8)
            module.E_accum = (accum - step * threshold).to(torch.int8)

    @staticmethod
    def _apply_accumulated_flips(module, accum_threshold=3):
        """Move packed trits by one step where ``T_accum`` has crossed ±1.

        Entries with ``T_accum > 1`` move their trit up one step, and entries
        with ``T_accum < -1`` move it down one step, by adding or subtracting
        ``3**pos`` in the byte that holds the trit. A trit already at the top
        (bottom) digit stays put instead of carrying into its neighbour.
        Every crossing accumulator is reset to 0.

        ``accum_threshold`` is accepted for interface compatibility. The
        crossing threshold is fixed at ±1.

        Returns:
            True if any accumulator crossed (and was reset), otherwise False.
        """
        if not hasattr(module, "T_accum") or not hasattr(module, "T_packed") or not hasattr(module, "_T_shape"):
            return False
        shape = tuple(int(v) for v in module._T_shape.tolist())
        if tuple(module.T_accum.shape) != shape:
            return False
        carry_up = module.T_accum > 1
        carry_down = module.T_accum < -1
        if not carry_up.any() and not carry_down.any():
            return False
        dev = module.T_packed.device
        out_dim, in_dim = shape
        n_trits = out_dim * in_dim
        flat_up = carry_up.reshape(-1).to(dev)
        flat_dn = carry_down.reshape(-1).to(dev)
        pk = module.T_packed.to(torch.int16).reshape(-1).clone()
        # Flat row-major packing, five trits per byte: trit `lin` lives in byte
        # `lin // 5` at base-3 digit `lin % 5`. Visiting one digit position per
        # pass touches each byte at most once, so the scatter never has
        # duplicate indices (even when in_dim is not a multiple of 5).
        for pos in range(5):
            lin = torch.arange(pos, n_trits, 5, device=dev)
            if lin.numel() == 0:
                continue
            up = flat_up[lin]
            dn = flat_dn[lin]
            if not up.any() and not dn.any():
                continue
            place = 3 ** pos
            byte_idx = lin // 5
            pv = pk[byte_idx]
            digit = (pv // place) % 3
            # Saturate within the digit range [0, 2]. Adding 3**pos to a digit
            # already at 2 would otherwise carry into the next trit.
            delta = (up & (digit < 2)).to(pv.dtype) - (dn & (digit > 0)).to(pv.dtype)
            pk[byte_idx] = pv + delta * place
        module.T_packed = pk.reshape(module.T_packed.shape).to(torch.uint8)
        # Reset T_accum to 0 on carry so W = T_accum × T doesn't jump
        module.T_accum.masked_fill_(carry_up | carry_down, 0)
        return True

    @staticmethod
    def _record_flip_health(module, had_flip):
        """Track steps since the last flip on modules that have ``T_accum``."""
        if not hasattr(module, "T_accum"):
            return
        steps_since = getattr(module, "_steps_since_flip", 0)
        module._steps_since_flip = 0 if had_flip else steps_since + 1
        module._had_flip = False

    def generate(self, idx, max_new_token, temperature=1.0, images=None, audio=None,
                 conversation_id=None, top_k=None, min_new_tokens=0, return_metadata=False):
        """Autoregressively sample byte ids and append them to ``idx``.

        Args:
            idx: prompt ids. Assumed (batch, seq), since new ids are
                concatenated on dim 1 and only the last ``CTX`` positions are
                fed to the model each step.
            max_new_token: number of sampling steps.
            temperature: softmax temperature. ``0`` selects greedy argmax
                decoding instead of dividing by zero.
            images, audio: forwarded to ``forward`` on every step.
            conversation_id: accepted but not used.
            top_k: if positive, sample only among the ``top_k`` largest logits.
            min_new_tokens: lower bound on generated ids. The loop runs
                ``max(max_new_token, min_new_tokens)`` steps.
            return_metadata: return a dict instead of the bare tensor.

        Returns:
            The extended id tensor. With ``return_metadata``, a dict with
            ``"tokens"``, ``"n_generated"`` (ids appended by this call) and
            ``"temperature"``.

        The KV ledger (and KQ cache) are seeded with the prompt ids when the
        ledger is empty. Sampling runs under ``torch.no_grad()``.
        """
        prompt_len = idx.shape[1]
        n_steps = max(max_new_token, min_new_tokens)
        with torch.no_grad():
            if self.kv_ledger is not None and self.kv_ledger.size == 0:
                for token_id in idx.reshape(-1).tolist():
                    self.kv_ledger.append(int(token_id))
                    if self.kq_cache is not None:
                        self.kq_cache.append(int(token_id))
            for i in range(n_steps):
                idx_cond = idx[:, -CTX:]
                logits, _, _, _ = self(idx_cond, images=images, audio=audio, timestep=i, output_mode="text")
                last_logits = logits[:, -1, :]
                if temperature == 0:
                    idx_next = last_logits.argmax(dim=-1, keepdim=True)
                else:
                    last_logits = last_logits / temperature
                    # top-k filtering
                    if top_k is not None and top_k > 0:
                        v, _ = torch.topk(last_logits, min(top_k, last_logits.size(-1)))
                        kth = v[:, -1].unsqueeze(-1).expand_as(last_logits)
                        last_logits = last_logits.where(last_logits >= kth, float('-inf'))
                    probs = F.softmax(last_logits, dim=-1)
                    idx_next = torch.multinomial(probs, num_samples=1)
                idx = torch.cat([idx, idx_next], dim=1)
        if return_metadata:
            return {
                "tokens": idx,
                "n_generated": idx.shape[1] - prompt_len,
                "temperature": temperature,
            }
        return idx