import argparse
import re
import time
from pathlib import Path

import mlx.core as mx
from transformers import AutoTokenizer

from model import load_model
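
# Example invocation (the script name is assumed here; adjust it to whatever
# this file is called in your checkout, e.g. generate.py):
#
#   python generate.py --model-path ./my-mlx-model \
#       --prompt "What is 17 * 23?" --final-only --stop-at-boxed --extract-boxed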


def generate_text(
    prompt: str,
    model_path: str,
    max_tokens: int = 100,
    temperature: float = 0.1,
    top_p: float = 0.9,
    system: str | None = None,
    final_only: bool = False,
    stop_at_boxed: bool = False,
    extract_boxed: bool = False,
    disable_chat_template: bool = False,
    repetition_penalty: float = 1.0,
    frequency_penalty: float = 0.0,
):
    """Generate text with the loaded MLX model using temperature, top-p, and
    repetition/frequency-penalty sampling."""
    print("Loading model and tokenizer...")
    model = load_model(model_path)
    tokenizer = AutoTokenizer.from_pretrained(model_path)

    # Use the model's chat template if one is bundled and not explicitly disabled.
    chat_template_path = Path(model_path) / "chat_template.jinja"
    use_chat_format = chat_template_path.exists() and not disable_chat_template

    print(f"Using chat template: {use_chat_format}")
    print("Starting generation...")
    print(f"Prompt: {prompt}")

    # Build the prompt, either via the chat template or as raw text.
    if use_chat_format:
        messages = []
        if system is None and final_only:
            system = (
                "You are a helpful assistant. Do not reveal your reasoning. "
                "Respond with only the final answer enclosed in \\boxed{...}."
            )
        if system is not None:
            messages.append({"role": "system", "content": system})
        messages.append({"role": "user", "content": prompt})
        formatted_prompt = tokenizer.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=True
        )
        print(f"Formatted prompt: {formatted_prompt}")
    else:
        # No chat template: prepend the BOS token (if any) to the raw prompt.
        bos = tokenizer.bos_token or ""
        formatted_prompt = f"{bos}{prompt}"

    # Tokenize the formatted prompt; special tokens were already handled above.
    prompt_tokens = tokenizer.encode(formatted_prompt, add_special_tokens=False)
    prompt_tokens = mx.array([prompt_tokens])

    print(f"Prompt tokens shape: {prompt_tokens.shape}")
    print(
        f"First few token IDs: {prompt_tokens[0, : min(10, prompt_tokens.shape[1])].tolist()}"
    )

    # Track timing and running per-token counts for the penalties below.
    start_time = time.time()
    generated_tokens = []
    freq_counts = {}
    running_text = ""  # decoded text so far, used only when stop_at_boxed is set
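
    # NOTE: each iteration below re-runs the model over the entire sequence
    # (prompt plus everything generated so far), so the loop is O(n^2) in the
    # number of generated tokens. A KV cache (as used by mlx-lm's generation
    # utilities, for example) would avoid the recomputation; it is omitted
    # here to keep the sampling logic easy to follow.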
    for i in range(max_tokens):
        # Forward pass over the full sequence; keep the logits for the last position.
        logits = model(prompt_tokens)
        next_token_logits = logits[0, -1, :]

        # Repetition penalty: shrink positive logits and push negative logits
        # further down for every token that has already been generated.
        if repetition_penalty and repetition_penalty != 1.0 and generated_tokens:
            logits_list = next_token_logits.tolist()
            for tid in set(generated_tokens):
                val = logits_list[tid]
                if val > 0:
                    logits_list[tid] = val / repetition_penalty
                else:
                    logits_list[tid] = val * repetition_penalty
            next_token_logits = mx.array(logits_list)

        # Frequency penalty: subtract frequency_penalty * count(token) from the
        # logits, reusing the running counts kept in freq_counts.
        if frequency_penalty and frequency_penalty > 0 and generated_tokens:
            vocab_size = next_token_logits.shape[-1]
            pen = [0.0] * vocab_size
            for tid, c in freq_counts.items():
                pen[tid] = frequency_penalty * float(c)
            next_token_logits = next_token_logits - mx.array(pen)

        # Greedy decoding when temperature == 0; otherwise temperature scaling
        # followed by optional nucleus (top-p) filtering.
        if temperature == 0:
            next_token = int(mx.argmax(next_token_logits).item())
        else:
            scaled_logits = next_token_logits / temperature

            if 0.0 < top_p < 1.0:
                # Keep the smallest set of tokens whose cumulative probability
                # reaches top_p and mask everything else to -inf.
                probs = mx.softmax(scaled_logits, axis=-1)
                sorted_probs = mx.sort(probs)[::-1]
                cumulative_probs = mx.cumsum(sorted_probs, axis=-1)
                cutoff_index = mx.sum(cumulative_probs < top_p)
                cutoff_prob = sorted_probs[cutoff_index.item()]
                mask = probs >= cutoff_prob
                scaled_logits = mx.where(mask, scaled_logits, float("-inf"))

            # Sample one token from the (possibly filtered) distribution.
            next_token = mx.random.categorical(scaled_logits, num_samples=1).item()

        # Stop on an end-of-sequence token; eos_token_id may be an int, a list,
        # or None depending on the tokenizer.
        eos_ids = tokenizer.eos_token_id
        if isinstance(eos_ids, (list, tuple)):
            stop_ids = set(int(t) for t in eos_ids)
        elif eos_ids is not None:
            stop_ids = {int(eos_ids)}
        else:
            stop_ids = set()
        if next_token in stop_ids:
            print(f"Stopping generation at EOS token: {next_token}")
            break

        generated_tokens.append(next_token)
        # Update the running counts used by the frequency penalty.
        freq_counts[next_token] = freq_counts.get(next_token, 0) + 1
        # Feed the new token back in so the next forward pass can see it.
        prompt_tokens = mx.concatenate(
            [prompt_tokens, mx.array([[next_token]])], axis=1
        )

        # Debug output for the first few generated tokens.
        if i < 10:
            token_text = tokenizer.decode([next_token])
            print(f"Token {i}: {next_token} -> '{token_text}'")

        # Optionally stop once a \boxed{...} answer has been closed.
        if stop_at_boxed:
            token_text_full = tokenizer.decode([next_token], skip_special_tokens=False)
            running_text += token_text_full
            box_idx = running_text.find("\\boxed{")
            # Only a closing brace *after* the \boxed{ opener counts, so braces
            # earlier in the output cannot trigger a premature stop.
            if box_idx != -1 and "}" in running_text[box_idx + len("\\boxed{"):]:
                print("Stopping generation at boxed answer.")
                break

    end_time = time.time()

    if generated_tokens:
        response = tokenizer.decode(generated_tokens, skip_special_tokens=True)
        print("\n--- Response ---")
        print(response)
    else:
        print("\n--- No tokens generated ---")

    print("------------------")

    generation_speed = (
        len(generated_tokens) / (end_time - start_time) if generated_tokens else 0
    )
    print(f"Generated {len(generated_tokens)} tokens")
    print(f"Generation speed: {generation_speed:.2f} tokens/sec")

    # Also show the raw decode, including any special tokens.
    if generated_tokens:
        full_response = tokenizer.decode(generated_tokens, skip_special_tokens=False)
        print(f"\nFull response (with special tokens): '{full_response}'")

    if extract_boxed and generated_tokens:
        # Keep the last \boxed{...} occurrence; the pattern assumes the boxed
        # content itself contains no nested braces.
        m = None
        for m in re.finditer(r"\\boxed\{([^}]*)\}", full_response):
            pass
        if m:
            print(f"\nExtracted boxed answer: {m.group(1).strip()}")
        else:
            print("\nNo \\boxed{...} segment found to extract.")
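

# Optional, more robust alternative to the regex used above in generate_text():
# a brace-counting extractor that also handles nested braces such as
# \boxed{\frac{1}{2}}. It is only a sketch and is not wired into the script;
# swap it in for the re.finditer() loop if nested content matters.
def extract_last_boxed(text: str) -> str | None:
    """Return the content of the last \\boxed{...} in ``text``, or None."""
    start = text.rfind("\\boxed{")
    if start == -1:
        return None
    i = start + len("\\boxed{")
    depth = 1
    chars = []
    while i < len(text):
        ch = text[i]
        if ch == "{":
            depth += 1
        elif ch == "}":
            depth -= 1
            if depth == 0:
                return "".join(chars)
        chars.append(ch)
        i += 1
    # No matching closing brace was produced; treat the match as incomplete.
    return None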


def main():
    parser = argparse.ArgumentParser(description="Run inference with the MLX model.")
    parser.add_argument(
        "--model-path", type=str, default=".", help="Path to the model directory."
    )
    parser.add_argument(
        "--prompt",
        type=str,
        default="What is the capital of France?",
        help="The prompt to start generation from.",
    )
    parser.add_argument(
        "--max-tokens",
        type=int,
        default=100,
        help="The maximum number of tokens to generate.",
    )
    parser.add_argument(
        "--temperature", type=float, default=0.1, help="Sampling temperature."
    )
    parser.add_argument(
        "--top-p", type=float, default=0.9, help="Top-p (nucleus) sampling parameter."
    )
    parser.add_argument(
        "--system", type=str, default=None, help="Optional system message for chat template."
    )
    parser.add_argument(
        "--final-only",
        action="store_true",
        help="Instruct the model to output only the final answer inside \\boxed{...}.",
    )
    parser.add_argument(
        "--stop-at-boxed",
        action="store_true",
        help="Stop generation once a closing '}' appears after \\boxed{.",
    )
    parser.add_argument(
        "--extract-boxed",
        action="store_true",
        help="Extract and print the content inside the last \\boxed{...} in the response.",
    )
    parser.add_argument(
        "--disable-chat-template",
        action="store_true",
        help="Ignore chat_template.jinja and feed the raw prompt (prepended with BOS).",
    )
    parser.add_argument(
        "--repetition-penalty",
        type=float,
        default=1.0,
        help="Penalty (>1.0) to discourage previously generated tokens.",
    )
    parser.add_argument(
        "--frequency-penalty",
        type=float,
        default=0.0,
        help="Subtract alpha * count(token) from logits before sampling.",
    )
    args = parser.parse_args()

    generate_text(
        args.prompt,
        args.model_path,
        max_tokens=args.max_tokens,
        temperature=args.temperature,
        top_p=args.top_p,
        system=args.system,
        final_only=args.final_only,
        stop_at_boxed=args.stop_at_boxed,
        extract_boxed=args.extract_boxed,
        disable_chat_template=args.disable_chat_template,
        repetition_penalty=args.repetition_penalty,
        frequency_penalty=args.frequency_penalty,
    )


if __name__ == "__main__":
    main()