"""
generation_utils.py — High-level generation helpers for SeqCond models.
These functions wrap SeqCondForCausalLM.generate() / generate_batch() with a
more user-friendly interface that handles tokenization, formatting, and
streaming.
Example usage:
from transformers import AutoModelForCausalLM, AutoTokenizer
model = AutoModelForCausalLM.from_pretrained("path/to/model", trust_remote_code=True)
tokenizer = AutoTokenizer.from_pretrained("path/to/model", trust_remote_code=True)
model.eval().cuda()
text = generate(model, tokenizer, "What is 2 + 2?")
print(text)
# Batched
texts = generate_batch(model, tokenizer, ["What is 2+2?", "Name a planet."])
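
    # Streaming (see stream() below; yields decoded fragments as they arrive)
    for fragment in stream(model, tokenizer, "Explain gravity."):
        print(fragment, end="", flush=True)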
"""
from typing import Iterator, List, Optional
import torch
import torch.nn.functional as F
_SEQ_LENS = [8, 16, 32, 64, 128, 256, 512, 1024, 2048, 4096] # power-of-2 for CUDA graphs
def _quantized_seq_len(pos: int) -> int:
    """Round position ``pos`` up to the next bucket in ``_SEQ_LENS``.

    Not called within this module; kept as a helper for callers that pad
    decode lengths to fixed buckets (e.g. when capturing CUDA graphs).
    """
    needed = pos + 1
    for s in _SEQ_LENS:
        if s >= needed:
            return s
    return _SEQ_LENS[-1]
@torch.no_grad()
def generate(
model,
tokenizer,
prompt: str,
max_new_tokens: int = 512,
temperature: float = 0.7,
top_p: float = 0.9,
top_k: int = 50,
repetition_penalty: float = 1.0,
use_chat_template: bool = True,
use_triton: bool = False,
strip_thinking: bool = False,
max_thinking_tokens: Optional[int] = None,
) -> str:
"""
Generate a single completion for *prompt*.
Args:
model: SeqCondForCausalLM instance.
tokenizer: SeqCondTokenizer instance.
prompt: Plain-text user prompt.
max_new_tokens: Maximum tokens to generate.
temperature: Sampling temperature (0 = greedy).
top_p: Nucleus sampling probability.
top_k: Top-k filtering (0 = disabled).
repetition_penalty: Penalty for repeating tokens.
use_chat_template: If True, wrap prompt in <|im_start|>user…<|think_start|>.
use_triton: If True, use Triton kernels for SeqCond steps.
strip_thinking: If True, return only the text after <|think_end|>.
max_thinking_tokens: If set, inject <|think_end|> after this many
thinking tokens to cap reasoning length.
Returns:
Generated text (completion only, EOS stripped).
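
    Example (a minimal sketch; assumes a model/tokenizer loaded as shown in the
    module docstring):

        text = generate(
            model, tokenizer, "Explain tides briefly.",
            max_new_tokens=256, temperature=0.0,  # temperature 0 -> greedy decoding
            max_thinking_tokens=128,              # cap the thinking span
            strip_thinking=True,                  # keep only the final answer
        )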
"""
device = next(model.parameters()).device
eos_id = tokenizer.im_end_id
think_end_id = tokenizer.think_end_id
if use_chat_template:
ids = tokenizer.encode_chat(prompt, add_think_start=True)
else:
ids = tokenizer.encode(prompt)
input_ids = torch.tensor([ids], dtype=torch.long, device=device)
logits, states = model.model.prefill(input_ids)
logits = logits.squeeze(1)
generated: List[int] = []
token_buf = torch.zeros((1, 1), dtype=torch.long, device=device)
seq_len = len(ids)
in_thinking = use_chat_template
thinking_tokens = 0
think_end_injected = False
for _ in range(max_new_tokens):
ls = logits[0] / max(temperature, 1e-8) if temperature > 0 else logits[0].clone()
if repetition_penalty != 1.0:
for t in set(generated):
if 0 <= t < model.config.vocab_size:
ls[t] /= repetition_penalty
if temperature == 0:
next_token = int(torch.argmax(ls))
else:
if top_k > 0:
kth = torch.topk(ls, top_k).values[-1]
ls = ls.masked_fill(ls < kth, float("-inf"))
            if top_p < 1.0:
                # Nucleus (top-p) filtering: sort descending, then mask the tail whose
                # cumulative probability exceeds top_p, always keeping the top token.
                sorted_ls, sorted_idx = torch.sort(ls, descending=True)
                cum = torch.cumsum(F.softmax(sorted_ls, dim=-1), dim=-1)
                remove = cum > top_p
                remove[1:] = remove[:-1].clone()
                remove[0] = False
                ls[sorted_idx[remove]] = float("-inf")
probs = F.softmax(ls, dim=-1)
next_token = int(torch.multinomial(probs, 1))
        # Thinking budget: if the model has produced max_thinking_tokens inside the
        # thinking span without closing it, force-inject <|think_end|> once so the
        # answer phase begins.
if next_token == think_end_id:
in_thinking = False
if in_thinking:
thinking_tokens += 1
if (
max_thinking_tokens is not None
and in_thinking
and thinking_tokens >= max_thinking_tokens
and not think_end_injected
):
next_token = think_end_id
in_thinking = False
think_end_injected = True
generated.append(next_token)
if next_token == eos_id:
break
token_buf[0, 0] = next_token
seq_len += 1
logits, states = model.model.step(token_buf, states, seq_len=seq_len, use_triton=use_triton)
# Decode
if generated and generated[-1] == eos_id:
generated = generated[:-1]
text = tokenizer.decode(generated)
if strip_thinking and "<|think_end|>" in text:
text = text.split("<|think_end|>", 1)[1].strip()
return text
@torch.no_grad()
def generate_batch(
model,
tokenizer,
prompts: List[str],
max_new_tokens: int = 512,
temperature: float = 0.7,
use_chat_template: bool = True,
use_triton: bool = False,
strip_thinking: bool = False,
) -> List[str]:
"""
Batched generation for a list of prompts.
Each prompt is prefilled individually (no padding noise), then all
sequences are decoded in lockstep with per-sample early stopping.
Returns a list of completion strings (EOS stripped).
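
    Example (sketch; assumes the same model/tokenizer setup as the module
    docstring):

        answers = generate_batch(
            model, tokenizer,
            ["What is 2+2?", "Name a planet."],
            max_new_tokens=64, temperature=0.0, strip_thinking=True,
        )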
"""
device = next(model.parameters()).device
eos_id = tokenizer.im_end_id
B = len(prompts)
if use_chat_template:
all_ids = [tokenizer.encode_chat(p, add_think_start=True) for p in prompts]
else:
all_ids = [tokenizer.encode(p) for p in prompts]
# Individual prefills
all_logits, all_states = [], []
for ids in all_ids:
inp = torch.tensor([ids], dtype=torch.long, device=device)
lg, st = model.model.prefill(inp)
all_logits.append(lg.squeeze(1))
all_states.append(st)
    logits = torch.cat(all_logits, dim=0)
    # Stack the per-sample states along the batch dimension so the decode loop
    # can advance every sequence with a single step() call.
    num_blocks = len(all_states[0])
    states = [
        tuple(torch.cat([s[i][j] for s in all_states], dim=0) for j in range(len(all_states[0][i])))
        for i in range(num_blocks)
    ]
generated = [[] for _ in range(B)]
finished = [False] * B
active_map = list(range(B))
token_buf = torch.zeros((B, 1), dtype=torch.long, device=device)
    # Shared position counter for the batch, seeded with the longest prompt length.
    seq_len = max(len(ids) for ids in all_ids)
for _ in range(max_new_tokens):
B_cur = len(active_map)
if B_cur == 0:
break
if temperature == 0:
next_tokens = torch.argmax(logits, dim=-1)
else:
probs = F.softmax(logits / max(temperature, 1e-8), dim=-1)
next_tokens = torch.multinomial(probs, 1).squeeze(-1)
newly_done: set = set()
for bi in range(B_cur):
oi = active_map[bi]
tok = int(next_tokens[bi])
generated[oi].append(tok)
if tok == eos_id:
finished[oi] = True
newly_done.add(bi)
else:
token_buf[bi, 0] = tok
if all(finished):
break
        # Compact the batch: drop finished rows from the token buffer, states,
        # logits, and the active-index map before the next decode step.
        if newly_done:
keep = [bi for bi in range(B_cur) if bi not in newly_done]
if not keep:
break
keep_idx = torch.tensor(keep, device=device)
token_buf = token_buf[keep_idx].contiguous()
states = [tuple(s[keep_idx].contiguous() for s in st) for st in states]
logits = logits[keep_idx]
active_map = [active_map[bi] for bi in keep]
seq_len += 1
logits, states = model.model.step(token_buf, states, seq_len=seq_len, use_triton=use_triton)
results = []
for toks in generated:
if toks and toks[-1] == eos_id:
toks = toks[:-1]
text = tokenizer.decode(toks)
if strip_thinking and "<|think_end|>" in text:
text = text.split("<|think_end|>", 1)[1].strip()
results.append(text)
return results
@torch.no_grad()
def stream(
model,
tokenizer,
prompt: str,
max_new_tokens: int = 512,
temperature: float = 0.7,
use_chat_template: bool = True,
use_triton: bool = False,
) -> Iterator[str]:
"""
Streaming token-by-token generation.
Yields decoded text fragments as they are produced. Useful for interactive
applications (e.g., a chat interface).
Example:
for fragment in stream(model, tokenizer, "Explain gravity."):
print(fragment, end="", flush=True)
"""
device = next(model.parameters()).device
eos_id = tokenizer.im_end_id
if use_chat_template:
ids = tokenizer.encode_chat(prompt, add_think_start=True)
else:
ids = tokenizer.encode(prompt)
input_ids = torch.tensor([ids], dtype=torch.long, device=device)
logits, states = model.model.prefill(input_ids)
logits = logits.squeeze(1)
token_buf = torch.zeros((1, 1), dtype=torch.long, device=device)
seq_len = len(ids)
for _ in range(max_new_tokens):
if temperature == 0:
next_token = int(torch.argmax(logits[0]))
else:
probs = F.softmax(logits[0] / max(temperature, 1e-8), dim=-1)
next_token = int(torch.multinomial(probs, 1))
if next_token == eos_id:
break
        try:
            yield tokenizer.decode([next_token])
        except Exception:
            # Best-effort decode: yield an empty fragment if this single token
            # cannot be decoded on its own.
            yield ""
token_buf[0, 0] = next_token
seq_len += 1
logits, states = model.model.step(token_buf, states, seq_len=seq_len, use_triton=use_triton)
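

# Minimal smoke test: a sketch, not part of the library API. The checkpoint
# path below is hypothetical; point it at a local SeqCond checkpoint and adjust
# the device to your setup.
if __name__ == "__main__":
    from transformers import AutoModelForCausalLM, AutoTokenizer

    _path = "path/to/model"  # hypothetical checkpoint directory
    _model = AutoModelForCausalLM.from_pretrained(_path, trust_remote_code=True)
    _tokenizer = AutoTokenizer.from_pretrained(_path, trust_remote_code=True)
    _model.eval()
    if torch.cuda.is_available():
        _model.cuda()

    # Single-prompt generation (greedy).
    print(generate(_model, _tokenizer, "What is 2 + 2?", max_new_tokens=64, temperature=0.0))

    # Streaming generation.
    for _fragment in stream(_model, _tokenizer, "Explain gravity in one sentence."):
        print(_fragment, end="", flush=True)
    print()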