"""
generation_utils.py — High-level generation helpers for SeqCond models.

These functions drive the model's low-level ``model.model.prefill()`` /
``model.model.step()`` API directly and add tokenization, chat formatting,
sampling, and streaming on top.

Example usage:

    from transformers import AutoModelForCausalLM, AutoTokenizer

    model = AutoModelForCausalLM.from_pretrained("path/to/model", trust_remote_code=True)
    tokenizer = AutoTokenizer.from_pretrained("path/to/model", trust_remote_code=True)
    model.eval().cuda()

    text = generate(model, tokenizer, "What is 2 + 2?")
    print(text)

    # Batched
    texts = generate_batch(model, tokenizer, ["What is 2+2?", "Name a planet."])

    # Streaming
    for fragment in stream(model, tokenizer, "Explain gravity."):
        print(fragment, end="", flush=True)
"""

from typing import Iterator, List, Optional

import torch
import torch.nn.functional as F

_SEQ_LENS = [8, 16, 32, 64, 128, 256, 512, 1024, 2048, 4096]  # power-of-2 for CUDA graphs


def _quantized_seq_len(pos: int) -> int:
    """Return the smallest bucket in ``_SEQ_LENS`` that holds ``pos + 1`` tokens.

    Positions beyond the largest bucket are clamped to ``_SEQ_LENS[-1]``, so in
    that case the result is *smaller* than ``pos + 1``; callers must tolerate it.
    """
    needed = pos + 1
    for s in _SEQ_LENS:
        if s >= needed:
            return s
    return _SEQ_LENS[-1]


def _encode_prompt(tokenizer, prompt: str, use_chat_template: bool) -> List[int]:
    """Tokenize *prompt*, optionally wrapped in the chat template with a think-start."""
    if use_chat_template:
        return tokenizer.encode_chat(prompt, add_think_start=True)
    return tokenizer.encode(prompt)


def _strip_thinking(text: str) -> str:
    """Return the text after the first ``<|think_end|>`` (stripped), or *text* unchanged."""
    _, sep, answer = text.partition("<|think_end|>")
    return answer.strip() if sep else text


def _apply_repetition_penalty(scores: torch.Tensor, token_ids: set, penalty: float) -> None:
    """Penalize *token_ids* in *scores* in place (CTRL-style).

    Positive scores are divided and negative scores multiplied by *penalty*, so a
    penalty > 1 always lowers the relative probability of a repeated token
    (plain division would raise it for negative scores).
    """
    idx = torch.tensor(sorted(token_ids), dtype=torch.long, device=scores.device)
    picked = scores[idx]
    scores[idx] = torch.where(picked > 0, picked / penalty, picked * penalty)


def _sample_token(scores: torch.Tensor, temperature: float, top_k: int, top_p: float) -> int:
    """Pick the next token id from a 1-D score vector.

    ``temperature == 0`` means greedy argmax. Otherwise top-k and then nucleus
    (top-p) filtering are applied before multinomial sampling. *scores* is
    expected to already be temperature-scaled and may be modified in place.
    """
    if temperature == 0:
        return int(torch.argmax(scores))
    if top_k > 0:
        # Clamp: torch.topk raises if k exceeds the vocabulary dimension.
        k = min(top_k, scores.size(-1))
        kth = torch.topk(scores, k).values[-1]
        scores = scores.masked_fill(scores < kth, float("-inf"))
    if top_p < 1.0:
        sorted_scores, sorted_idx = torch.sort(scores, descending=True)
        cum = torch.cumsum(F.softmax(sorted_scores, dim=-1), dim=-1)
        remove = cum > top_p
        # Shift right so the token that crosses the threshold is kept, and
        # always keep the single most likely token.
        remove[1:] = remove[:-1].clone()
        remove[0] = False
        scores[sorted_idx[remove]] = float("-inf")
    probs = F.softmax(scores, dim=-1)
    return int(torch.multinomial(probs, 1))


def _safe_decode(tokenizer, token_ids: List[int]) -> Optional[str]:
    """Decode *token_ids*, returning None instead of raising."""
    try:
        return tokenizer.decode(token_ids)
    except Exception:  # deliberate best-effort: a bad decode must not kill a stream
        return None


@torch.no_grad()
def generate(
    model,
    tokenizer,
    prompt: str,
    max_new_tokens: int = 512,
    temperature: float = 0.7,
    top_p: float = 0.9,
    top_k: int = 50,
    repetition_penalty: float = 1.0,
    use_chat_template: bool = True,
    use_triton: bool = False,
    strip_thinking: bool = False,
    max_thinking_tokens: Optional[int] = None,
) -> str:
    """
    Generate a single completion for *prompt*.

    Args:
        model: SeqCondForCausalLM instance.
        tokenizer: SeqCondTokenizer instance.
        prompt: Plain-text user prompt.
        max_new_tokens: Maximum tokens to generate.
        temperature: Sampling temperature (0 = greedy).
        top_p: Nucleus sampling probability.
        top_k: Top-k filtering (0 = disabled; values larger than the vocabulary
            keep every token).
        repetition_penalty: CTRL-style penalty for tokens already generated:
            positive logits are divided and negative logits multiplied by it,
            so values > 1 always discourage repeats. 1.0 disables it.
        use_chat_template: If True, wrap prompt in <|im_start|>user…<|think_start|>.
        use_triton: If True, use Triton kernels for SeqCond steps.
        strip_thinking: If True, return only the text after <|think_end|>.
        max_thinking_tokens: If set, force <|think_end|> once this many
            thinking tokens have been produced, capping reasoning length.
            Thinking mode is only active when use_chat_template is True.

    Returns:
        Generated text (completion only, EOS stripped).
    """
    device = next(model.parameters()).device
    eos_id = tokenizer.im_end_id
    think_end_id = tokenizer.think_end_id

    penalize = repetition_penalty != 1.0
    vocab_size = model.config.vocab_size if penalize else 0
    penalized: set = set()  # distinct in-vocab token ids generated so far

    ids = _encode_prompt(tokenizer, prompt, use_chat_template)
    input_ids = torch.tensor([ids], dtype=torch.long, device=device)
    logits, states = model.model.prefill(input_ids)
    logits = logits.squeeze(1)  # presumably (1, 1, V) -> (1, V); TODO confirm prefill shape

    generated: List[int] = []
    token_buf = torch.zeros((1, 1), dtype=torch.long, device=device)
    seq_len = len(ids)
    # The chat template ends with <|think_start|>, so generation starts in thinking mode.
    in_thinking = use_chat_template
    thinking_tokens = 0

    for _ in range(max_new_tokens):
        # Always work on a fresh tensor so the model's logits are never mutated.
        scores = logits[0] / max(temperature, 1e-8) if temperature > 0 else logits[0].clone()
        if penalize and penalized:
            _apply_repetition_penalty(scores, penalized, repetition_penalty)
        next_token = _sample_token(scores, temperature, top_k, top_p)

        # Thinking budget: check *before* counting the current token, so exactly
        # max_thinking_tokens thinking tokens are emitted before <|think_end|>.
        if in_thinking:
            if next_token == think_end_id:
                in_thinking = False
            elif max_thinking_tokens is not None and thinking_tokens >= max_thinking_tokens:
                next_token = think_end_id
                in_thinking = False
            else:
                thinking_tokens += 1

        generated.append(next_token)
        if penalize and 0 <= next_token < vocab_size:
            penalized.add(next_token)
        if next_token == eos_id:
            break

        token_buf[0, 0] = next_token
        seq_len += 1
        logits, states = model.model.step(token_buf, states, seq_len=seq_len, use_triton=use_triton)

    if generated and generated[-1] == eos_id:
        generated = generated[:-1]
    text = tokenizer.decode(generated)
    return _strip_thinking(text) if strip_thinking else text


@torch.no_grad()
def generate_batch(
    model,
    tokenizer,
    prompts: List[str],
    max_new_tokens: int = 512,
    temperature: float = 0.7,
    use_chat_template: bool = True,
    use_triton: bool = False,
    strip_thinking: bool = False,
) -> List[str]:
    """
    Batched generation for a list of prompts.

    Each prompt is prefilled individually (no padding noise), then all sequences
    are decoded in lockstep with per-sample early stopping: rows that emit EOS
    are dropped from the batch (tokens and states) before the next step.
    Sampling is greedy when temperature == 0, otherwise plain temperature
    sampling (no top-k / top-p / repetition penalty).

    Returns a list of completion strings (EOS stripped), in prompt order.
    An empty *prompts* list returns an empty list.
    """
    if not prompts:
        return []

    device = next(model.parameters()).device
    eos_id = tokenizer.im_end_id
    B = len(prompts)

    all_ids = [_encode_prompt(tokenizer, p, use_chat_template) for p in prompts]

    # Individual prefills, so no padding tokens ever enter a sequence's state.
    all_logits, all_states = [], []
    for ids in all_ids:
        inp = torch.tensor([ids], dtype=torch.long, device=device)
        lg, st = model.model.prefill(inp)
        all_logits.append(lg.squeeze(1))
        all_states.append(st)
    logits = torch.cat(all_logits, dim=0)

    # Stitch per-prompt states into one batch along dim 0. Assumes each state
    # tensor has batch as dim 0 and the same trailing shape for every prompt
    # (TODO confirm against the model's state layout).
    num_blocks = len(all_states[0])
    states = [
        tuple(
            torch.cat([s[i][j] for s in all_states], dim=0)
            for j in range(len(all_states[0][i]))
        )
        for i in range(num_blocks)
    ]

    generated: List[List[int]] = [[] for _ in range(B)]
    active = list(range(B))  # prompt index of each row still in the batch
    token_buf = torch.zeros((B, 1), dtype=torch.long, device=device)
    # NOTE(review): prompts can differ in length but step() receives a single
    # seq_len (the longest prompt's). Whether that is correct depends on how
    # step() uses seq_len — confirm against the model implementation.
    seq_len = max(len(ids) for ids in all_ids)

    for _ in range(max_new_tokens):
        if temperature == 0:
            next_tokens = torch.argmax(logits, dim=-1)
        else:
            probs = F.softmax(logits / max(temperature, 1e-8), dim=-1)
            next_tokens = torch.multinomial(probs, 1).squeeze(-1)

        keep: List[int] = []  # batch rows that did not just emit EOS
        for row, tok in enumerate(next_tokens.tolist()):
            generated[active[row]].append(tok)
            if tok != eos_id:
                token_buf[row, 0] = tok
                keep.append(row)

        if not keep:
            break
        if len(keep) < len(active):
            keep_idx = torch.tensor(keep, device=device)
            token_buf = token_buf[keep_idx].contiguous()
            states = [tuple(s[keep_idx].contiguous() for s in st) for st in states]
            active = [active[row] for row in keep]

        seq_len += 1
        logits, states = model.model.step(token_buf, states, seq_len=seq_len, use_triton=use_triton)

    results = []
    for toks in generated:
        if toks and toks[-1] == eos_id:
            toks = toks[:-1]
        text = tokenizer.decode(toks)
        results.append(_strip_thinking(text) if strip_thinking else text)
    return results


@torch.no_grad()
def stream(
    model,
    tokenizer,
    prompt: str,
    max_new_tokens: int = 512,
    temperature: float = 0.7,
    use_chat_template: bool = True,
    use_triton: bool = False,
) -> Iterator[str]:
    """
    Streaming token-by-token generation.

    Yields decoded text fragments as they are produced. Useful for interactive
    applications (e.g., a chat interface). One fragment is yielded per generated
    (non-EOS) token; concatenating all fragments gives the full completion.

    Text is decoded incrementally over the whole generated sequence, and only
    the newly produced suffix is yielded, so characters spanning several tokens
    (e.g. multi-byte UTF-8) and tokenizer spacing come out intact. While the
    decoded text ends in U+FFFD (an incomplete character), an empty fragment is
    yielded and the text is held back; anything still held back when generation
    stops is flushed as one final fragment. Decode failures are best-effort:
    they yield "" rather than raising. Re-decoding the full sequence each step
    is quadratic in output length, which is acceptable for typical budgets.

    Example:
        for fragment in stream(model, tokenizer, "Explain gravity."):
            print(fragment, end="", flush=True)
    """
    device = next(model.parameters()).device
    eos_id = tokenizer.im_end_id

    ids = _encode_prompt(tokenizer, prompt, use_chat_template)
    input_ids = torch.tensor([ids], dtype=torch.long, device=device)
    logits, states = model.model.prefill(input_ids)
    logits = logits.squeeze(1)

    token_buf = torch.zeros((1, 1), dtype=torch.long, device=device)
    seq_len = len(ids)
    generated: List[int] = []
    printed_len = 0  # characters of the decoded text already yielded

    for _ in range(max_new_tokens):
        if temperature == 0:
            next_token = int(torch.argmax(logits[0]))
        else:
            probs = F.softmax(logits[0] / max(temperature, 1e-8), dim=-1)
            next_token = int(torch.multinomial(probs, 1))
        if next_token == eos_id:
            break
        generated.append(next_token)

        text = _safe_decode(tokenizer, generated)
        fragment = ""
        if text is not None and not text.endswith("\ufffd"):
            fragment = text[printed_len:]
            printed_len = max(printed_len, len(text))
        yield fragment

        token_buf[0, 0] = next_token
        seq_len += 1
        logits, states = model.model.step(token_buf, states, seq_len=seq_len, use_triton=use_triton)

    # Flush text that was held back (e.g. a trailing incomplete character).
    text = _safe_decode(tokenizer, generated)
    if text is not None and len(text) > printed_len:
        yield text[printed_len:]