| """inference.py - Code generation model wrapper for smolagents""" |
| import torch |
| from transformers import AutoModelForCausalLM, AutoTokenizer |
|
|
|
|
| class CodeModel: |
    def __init__(self, model_id: str, device: str | None = None):
        self.device = device or ("cuda" if torch.cuda.is_available() else "cpu")
        self.tokenizer = AutoTokenizer.from_pretrained(model_id)
        # Many causal LM tokenizers ship without a pad token; fall back to EOS
        # so generate() always has a valid pad_token_id.
        if self.tokenizer.pad_token is None:
            self.tokenizer.pad_token = self.tokenizer.eos_token
        # Load weights directly in the target dtype to avoid materializing a full fp32 copy first.
        dtype = torch.bfloat16 if self.device == "cuda" else torch.float32
        self.model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=dtype).to(self.device)
        self.model.eval()

    def generate(self, prompt: str, max_new_tokens: int = 512, temperature: float = 0.7) -> str:
        inputs = self.tokenizer(prompt, return_tensors="pt").to(self.device)

        with torch.no_grad():
            outputs = self.model.generate(
                **inputs,
                max_new_tokens=max_new_tokens,
                temperature=temperature,
                do_sample=True,
                top_p=0.9,
                repetition_penalty=1.2,
                pad_token_id=self.tokenizer.pad_token_id,
                eos_token_id=self.tokenizer.eos_token_id,
            )

        # Slice off the prompt so only newly generated text is returned.
        new_tokens = outputs[0, inputs["input_ids"].shape[1]:]
        return self.tokenizer.decode(new_tokens, skip_special_tokens=True)

    def chat(self, messages: list[dict], max_new_tokens: int = 256) -> str:
| """Generate response using chat template.""" |
| text = self.tokenizer.apply_chat_template( |
| messages, |
| add_generation_prompt=True, |
| tokenize=False |
| ) |
| inputs = self.tokenizer(text, return_tensors="pt").to(self.device) |
|
|
| with torch.no_grad(): |
| outputs = self.model.generate( |
| **inputs, |
| max_new_tokens=max_new_tokens, |
| do_sample=True, |
| temperature=0.7, |
| top_p=0.9, |
| repetition_penalty=1.2, |
| ) |
|
|
| new_tokens = outputs[0, inputs["input_ids"].shape[1]:] |
| return self.tokenizer.decode(new_tokens, skip_special_tokens=False) |
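
# To drive this checkpoint from a smolagents agent, one option is to skip the
# wrapper and let smolagents load the model itself. A minimal sketch based on
# the smolagents quickstart (TransformersModel / CodeAgent); verify the exact
# signatures against your installed smolagents version:
#
#     from smolagents import CodeAgent, TransformersModel
#
#     agent_model = TransformersModel(model_id="AutomatedScientist/pynb-73m-base")
#     agent = CodeAgent(tools=[], model=agent_model)
#     agent.run("Write a function to reverse a string")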

if __name__ == "__main__":
    import os
    # Prefer a local fine-tuned checkpoint when present; otherwise fall back to the hub model.
    model_id = "checkpoint" if os.path.exists("checkpoint") else "AutomatedScientist/pynb-73m-base"
    model = CodeModel(model_id)

    result = model.generate("Write a Python function to calculate factorial")
    print("Generated code:")
    print(result)
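
    # Lower temperatures make sampling more conservative and repeatable; the
    # temperature knob is already exposed on generate(), so a quick comparison:
    conservative = model.generate(
        "Write a Python function to calculate factorial",
        temperature=0.2,
    )
    print("\nLow-temperature variant:")
    print(conservative)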
    messages = [
        {"role": "system", "content": "You are a helpful coding assistant."},
        {"role": "user", "content": "Write a function to reverse a string"},
    ]
    response = model.chat(messages)
    print("\nChat response:")
    print(response)