| """ |
| generation_utils.py — High-level generation helpers for SeqCond models. |
| |
| These functions wrap SeqCondForCausalLM.generate() / generate_batch() with a |
| more user-friendly interface that handles tokenization, formatting, and |
| streaming. |
| |
| Example usage: |
| from transformers import AutoModelForCausalLM, AutoTokenizer |
| model = AutoModelForCausalLM.from_pretrained("path/to/model", trust_remote_code=True) |
| tokenizer = AutoTokenizer.from_pretrained("path/to/model", trust_remote_code=True) |
| model.eval().cuda() |
| |
| text = generate(model, tokenizer, "What is 2 + 2?") |
| print(text) |
| |
| # Batched |
| texts = generate_batch(model, tokenizer, ["What is 2+2?", "Name a planet."]) |
| """ |
|
|
| from typing import Iterator, List, Optional |
|
|
| import torch |
| import torch.nn.functional as F |
|
|
|
|
| _SEQ_LENS = [8, 16, 32, 64, 128, 256, 512, 1024, 2048, 4096] |
|
|
|
|
| def _quantized_seq_len(pos: int) -> int: |
| needed = pos + 1 |
| for s in _SEQ_LENS: |
| if s >= needed: |
| return s |
| return _SEQ_LENS[-1] |
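
# Illustrative values for the bucketing helper:
#     _quantized_seq_len(0)      -> 8      (needs 1 slot)
#     _quantized_seq_len(100)    -> 128    (needs 101 slots)
#     _quantized_seq_len(10_000) -> 4096   (past the largest bucket)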


@torch.no_grad()
def generate(
    model,
    tokenizer,
    prompt: str,
    max_new_tokens: int = 512,
    temperature: float = 0.7,
    top_p: float = 0.9,
    top_k: int = 50,
    repetition_penalty: float = 1.0,
    use_chat_template: bool = True,
    use_triton: bool = False,
    strip_thinking: bool = False,
    max_thinking_tokens: Optional[int] = None,
) -> str:
    """
    Generate a single completion for *prompt*.

    Args:
        model: SeqCondForCausalLM instance.
        tokenizer: SeqCondTokenizer instance.
        prompt: Plain-text user prompt.
        max_new_tokens: Maximum number of tokens to generate.
        temperature: Sampling temperature (0 = greedy).
        top_p: Nucleus sampling probability.
        top_k: Top-k filtering (0 = disabled).
        repetition_penalty: Penalty applied to previously generated tokens.
        use_chat_template: If True, wrap the prompt in <|im_start|>user…<|think_start|>.
        use_triton: If True, use Triton kernels for SeqCond steps.
        strip_thinking: If True, return only the text after <|think_end|>.
        max_thinking_tokens: If set, inject <|think_end|> after this many
            thinking tokens to cap reasoning length.

    Returns:
        Generated text (completion only, EOS stripped).
    """
    device = next(model.parameters()).device
    eos_id = tokenizer.im_end_id
    think_end_id = tokenizer.think_end_id

    if use_chat_template:
        ids = tokenizer.encode_chat(prompt, add_think_start=True)
    else:
        ids = tokenizer.encode(prompt)

    input_ids = torch.tensor([ids], dtype=torch.long, device=device)
    logits, states = model.model.prefill(input_ids)
    logits = logits.squeeze(1)

    generated: List[int] = []
    token_buf = torch.zeros((1, 1), dtype=torch.long, device=device)
    seq_len = len(ids)

    in_thinking = use_chat_template
    thinking_tokens = 0
    think_end_injected = False

    for _ in range(max_new_tokens):
        # Work on a copy so the penalties below never mutate `logits` in place.
        ls = logits[0] / max(temperature, 1e-8) if temperature > 0 else logits[0].clone()

        if repetition_penalty != 1.0:
            for t in set(generated):
                if 0 <= t < model.config.vocab_size:
                    # Handle positive and negative logits separately so a
                    # penalty > 1 always lowers the token's probability.
                    if ls[t] > 0:
                        ls[t] /= repetition_penalty
                    else:
                        ls[t] *= repetition_penalty

        if temperature == 0:
            next_token = int(torch.argmax(ls))
        else:
            if top_k > 0:
                kth = torch.topk(ls, top_k).values[-1]
                ls = ls.masked_fill(ls < kth, float("-inf"))
            if top_p < 1.0:
                sorted_ls, sorted_idx = torch.sort(ls, descending=True)
                cum = torch.cumsum(F.softmax(sorted_ls, dim=-1), dim=-1)
                remove = cum > top_p
                # Shift right so the first token over the threshold is kept.
                remove[1:] = remove[:-1].clone()
                remove[0] = False
                ls[sorted_idx[remove]] = float("-inf")
            probs = F.softmax(ls, dim=-1)
            next_token = int(torch.multinomial(probs, 1))

        # Track whether we are still inside the <|think_start|>…<|think_end|> span.
        if next_token == think_end_id:
            in_thinking = False
        if in_thinking:
            thinking_tokens += 1
        if (
            max_thinking_tokens is not None
            and in_thinking
            and thinking_tokens >= max_thinking_tokens
            and not think_end_injected
        ):
            # Force the model out of thinking mode once the budget is spent.
            next_token = think_end_id
            in_thinking = False
            think_end_injected = True

        generated.append(next_token)
        if next_token == eos_id:
            break

        token_buf[0, 0] = next_token
        seq_len += 1
        logits, states = model.model.step(token_buf, states, seq_len=seq_len, use_triton=use_triton)

    if generated and generated[-1] == eos_id:
        generated = generated[:-1]

    text = tokenizer.decode(generated)
    if strip_thinking and "<|think_end|>" in text:
        text = text.split("<|think_end|>", 1)[1].strip()
    return text
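
# Illustrative usage sketch: greedy decoding with the reasoning span capped
# and stripped. `model` / `tokenizer` are assumed to be loaded as in the
# module docstring.
#
#     answer = generate(
#         model, tokenizer, "What is 2 + 2?",
#         temperature=0.0,
#         max_thinking_tokens=256,
#         strip_thinking=True,
#     )
#     print(answer)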


@torch.no_grad()
def generate_batch(
    model,
    tokenizer,
    prompts: List[str],
    max_new_tokens: int = 512,
    temperature: float = 0.7,
    use_chat_template: bool = True,
    use_triton: bool = False,
    strip_thinking: bool = False,
) -> List[str]:
    """
    Batched generation for a list of prompts.

    Each prompt is prefilled individually (no padding noise), then all
    sequences are decoded in lockstep with per-sample early stopping.

    Returns a list of completion strings (EOS stripped).
    """
    device = next(model.parameters()).device
    eos_id = tokenizer.im_end_id
    B = len(prompts)

    if use_chat_template:
        all_ids = [tokenizer.encode_chat(p, add_think_start=True) for p in prompts]
    else:
        all_ids = [tokenizer.encode(p) for p in prompts]

    # Prefill each prompt on its own (batch size 1), collecting the final
    # logits and per-block states.
    all_logits, all_states = [], []
    for ids in all_ids:
        inp = torch.tensor([ids], dtype=torch.long, device=device)
        lg, st = model.model.prefill(inp)
        all_logits.append(lg.squeeze(1))
        all_states.append(st)

    # Merge the per-sample states: `states` becomes a list with one tuple of
    # tensors per block, each tensor concatenated along the batch dimension.
    logits = torch.cat(all_logits, dim=0)
    num_blocks = len(all_states[0])
    states = [
        tuple(torch.cat([s[i][j] for s in all_states], dim=0) for j in range(len(all_states[0][i])))
        for i in range(num_blocks)
    ]

    generated = [[] for _ in range(B)]
    finished = [False] * B
    active_map = list(range(B))  # active-batch row -> original prompt index
    token_buf = torch.zeros((B, 1), dtype=torch.long, device=device)
    # A single position counter (that of the longest prompt) is shared by all samples.
    seq_len = max(len(ids) for ids in all_ids)

    for _ in range(max_new_tokens):
        B_cur = len(active_map)
        if B_cur == 0:
            break

        if temperature == 0:
            next_tokens = torch.argmax(logits, dim=-1)
        else:
            probs = F.softmax(logits / max(temperature, 1e-8), dim=-1)
            next_tokens = torch.multinomial(probs, 1).squeeze(-1)

        newly_done: set = set()
        for bi in range(B_cur):
            oi = active_map[bi]
            tok = int(next_tokens[bi])
            generated[oi].append(tok)
            if tok == eos_id:
                finished[oi] = True
                newly_done.add(bi)
            else:
                token_buf[bi, 0] = tok

        if all(finished):
            break

        # Drop finished rows from the token buffer, states, and logits so the
        # remaining samples keep decoding in a smaller batch.
        if newly_done:
            keep = [bi for bi in range(B_cur) if bi not in newly_done]
            if not keep:
                break
            keep_idx = torch.tensor(keep, device=device)
            token_buf = token_buf[keep_idx].contiguous()
            states = [tuple(s[keep_idx].contiguous() for s in st) for st in states]
            logits = logits[keep_idx]
            active_map = [active_map[bi] for bi in keep]

        seq_len += 1
        logits, states = model.model.step(token_buf, states, seq_len=seq_len, use_triton=use_triton)

    results = []
    for toks in generated:
        if toks and toks[-1] == eos_id:
            toks = toks[:-1]
        text = tokenizer.decode(toks)
        if strip_thinking and "<|think_end|>" in text:
            text = text.split("<|think_end|>", 1)[1].strip()
        results.append(text)
    return results
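
# Illustrative usage sketch: greedy batched decoding with reasoning spans
# stripped; `model` / `tokenizer` are assumed to be loaded as in the module
# docstring.
#
#     answers = generate_batch(
#         model, tokenizer,
#         ["What is 2 + 2?", "Name a planet."],
#         temperature=0.0,
#         strip_thinking=True,
#     )
#     for a in answers:
#         print(a)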


@torch.no_grad()
def stream(
    model,
    tokenizer,
    prompt: str,
    max_new_tokens: int = 512,
    temperature: float = 0.7,
    use_chat_template: bool = True,
    use_triton: bool = False,
) -> Iterator[str]:
    """
    Streaming token-by-token generation.

    Yields decoded text fragments as they are produced. Useful for interactive
    applications (e.g., a chat interface).

    Example:
        for fragment in stream(model, tokenizer, "Explain gravity."):
            print(fragment, end="", flush=True)
    """
    device = next(model.parameters()).device
    eos_id = tokenizer.im_end_id

    if use_chat_template:
        ids = tokenizer.encode_chat(prompt, add_think_start=True)
    else:
        ids = tokenizer.encode(prompt)

    input_ids = torch.tensor([ids], dtype=torch.long, device=device)
    logits, states = model.model.prefill(input_ids)
    logits = logits.squeeze(1)

    token_buf = torch.zeros((1, 1), dtype=torch.long, device=device)
    seq_len = len(ids)

    for _ in range(max_new_tokens):
        if temperature == 0:
            next_token = int(torch.argmax(logits[0]))
        else:
            probs = F.softmax(logits[0] / max(temperature, 1e-8), dim=-1)
            next_token = int(torch.multinomial(probs, 1))

        if next_token == eos_id:
            break

        # If a token cannot be decoded on its own, emit an empty fragment.
        try:
            yield tokenizer.decode([next_token])
        except Exception:
            yield ""

        token_buf[0, 0] = next_token
        seq_len += 1
        logits, states = model.model.step(token_buf, states, seq_len=seq_len, use_triton=use_triton)
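

if __name__ == "__main__":  # pragma: no cover
    # Minimal smoke test; the checkpoint path below is a placeholder, pass a
    # real path on the command line to run it.
    import sys

    from transformers import AutoModelForCausalLM, AutoTokenizer

    path = sys.argv[1] if len(sys.argv) > 1 else "path/to/model"
    model = AutoModelForCausalLM.from_pretrained(path, trust_remote_code=True)
    tokenizer = AutoTokenizer.from_pretrained(path, trust_remote_code=True)
    model.eval()
    if torch.cuda.is_available():
        model.cuda()

    print(generate(model, tokenizer, "What is 2 + 2?", temperature=0.0))
    for fragment in stream(model, tokenizer, "Explain gravity."):
        print(fragment, end="", flush=True)
    print()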