|
|
| """
|
| Script to use a trained GRPO model for arithmetic countdown problems.
|
|
|
| This script loads a model trained with train_grpo_hydra.py and provides
|
| both interactive and batch evaluation modes for solving arithmetic problems.
|
| """
|
|
|
| import logging
|
| import sys
|
| from pathlib import Path
|
|
|
| import torch
|
| from peft import PeftModel
|
| from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig
|
|
|
|
|
# Make the project root importable so the `src.*` modules below resolve
# when this script is executed directly from its own directory.
_project_root = Path(__file__).parent.parent
sys.path.append(str(_project_root))
|
|
|
| from src.dataset.grpo import map_problem_description_to_conversation_grpo
|
| from src.utils.rewards import _is_valid_arithmetic_expression
|
| from src.utils.string_helper import extract_answer
|
|
|
|
|
# Configure root logging once at import time; all inference messages go
# through this module-level logger.
logging.basicConfig(
    format="%(asctime)s - %(levelname)s - %(message)s",
    level=logging.INFO,
)
logger = logging.getLogger("model_inference")
|
|
|
|
|
class GRPOModelInference:
    """Load a GRPO-trained model and run inference on countdown problems.

    The base model is loaded first; the optional SFT and GRPO LoRA adapters
    are then applied and merged in training order (SFT first, GRPO on top),
    so generation runs on a single fused model.
    """

    def __init__(
        self,
        sft_model_path: str | None = None,
        grpo_model_path: str | None = None,
        base_model_id: str = "Qwen/Qwen2.5-Math-1.5B",
        device: str = "auto",
        dtype: torch.dtype = torch.float16,
    ):
        """
        Initialize the model inference class.

        Args:
            sft_model_path: Path to the SFT LoRA adapter directory, or None
                to skip the SFT adapter.
            grpo_model_path: Path to the GRPO LoRA adapter directory, or None
                to skip the GRPO adapter.
            base_model_id: Base model identifier from Hugging Face.
            device: Device map passed to ``from_pretrained`` (e.g. "auto").
            dtype: Torch data type for the model weights.
        """
        self.sft_model_path = sft_model_path
        self.grpo_model_path = grpo_model_path
        self.base_model_id = base_model_id
        self.device = device
        self.dtype = dtype

        # Populated by _load_model().
        self.tokenizer = None
        self.model = None

        self._load_model()

    def _load_model(self) -> None:
        """Load the tokenizer and base model, then merge any LoRA adapters."""
        logger.info(f"Loading base model: {self.base_model_id}")

        self.tokenizer = AutoTokenizer.from_pretrained(self.base_model_id)
        # Some tokenizers ship without a pad token; both `padding=True` in
        # _generate_response and GenerationConfig's pad_token_id need one,
        # so fall back to the EOS token.
        if self.tokenizer.pad_token is None:
            self.tokenizer.pad_token = self.tokenizer.eos_token

        # NOTE(review): `dtype=` is the newer transformers spelling; older
        # releases expect `torch_dtype=`. Confirm against the pinned version.
        self.model = AutoModelForCausalLM.from_pretrained(
            self.base_model_id,
            dtype=self.dtype,
            device_map=self.device,
        )

        # Apply adapters in training order: SFT first, then GRPO on top.
        # Each adapter is merged into the weights so the next step (and
        # generation) sees a plain fused model.
        if self.sft_model_path:
            logger.info(f"Loading SFT LoRA adapters from: {self.sft_model_path}")
            self.model = PeftModel.from_pretrained(self.model, self.sft_model_path)
            self.model = self.model.merge_and_unload()

        if self.grpo_model_path:
            logger.info(f"Loading GRPO LoRA adapters from: {self.grpo_model_path}")
            self.model = PeftModel.from_pretrained(self.model, self.grpo_model_path)
            self.model = self.model.merge_and_unload()

        self.model.eval()
        logger.info("Model loaded successfully")

    def _format_conversation(self, problem_description: str) -> list[dict[str, str]]:
        """
        Format the problem description into the expected conversation format.

        Args:
            problem_description: The arithmetic problem description

        Returns:
            List of conversation messages (the "prompt" field produced by
            the GRPO dataset mapping)
        """
        result = map_problem_description_to_conversation_grpo(
            {
                "problem_description": problem_description,
            }
        )
        return result["prompt"]

    def _generate_response(
        self,
        messages: list[dict[str, str]],
        max_new_tokens: int = 512,
        temperature: float = 0.7,
        do_sample: bool = True,
        top_p: float = 0.9,
    ) -> str:
        """
        Generate a response from the model given conversation messages.

        Args:
            messages: List of conversation messages
            max_new_tokens: Maximum number of new tokens to generate
            temperature: Sampling temperature
            do_sample: Whether to use sampling
            top_p: Top-p sampling parameter

        Returns:
            Generated response text (prompt stripped, whitespace trimmed)
        """
        # Render the conversation with the model's chat template; the
        # generation prompt marks where the assistant turn begins.
        formatted_prompt = self.tokenizer.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=True
        )

        inputs = self.tokenizer(
            formatted_prompt,
            return_tensors="pt",
            padding=True,
            truncation=True,
            max_length=4096,
        )

        # With device_map="auto" the model may be sharded; prefer the
        # model-level device attribute, otherwise fall back to a parameter.
        if hasattr(self.model, "device"):
            device = self.model.device
        else:
            device = next(self.model.parameters()).device

        inputs = {k: v.to(device) for k, v in inputs.items()}

        generation_config = GenerationConfig(
            max_new_tokens=max_new_tokens,
            temperature=temperature,
            do_sample=do_sample,
            top_p=top_p,
            pad_token_id=self.tokenizer.pad_token_id,
            eos_token_id=self.tokenizer.eos_token_id,
        )

        with torch.no_grad():
            outputs = self.model.generate(
                **inputs,
                generation_config=generation_config,
            )

        # Decode only the newly generated tokens (skip the prompt tokens).
        response = self.tokenizer.decode(
            outputs[0][inputs["input_ids"].shape[1] :], skip_special_tokens=True
        )

        return response.strip()

    def solve_problem(
        self,
        problem_description: str,
        max_new_tokens: int = 512,
        temperature: float = 0.7,
    ) -> tuple[str, str, bool]:
        """
        Solve a single arithmetic countdown problem.

        Args:
            problem_description: The problem description
            max_new_tokens: Maximum tokens to generate
            temperature: Sampling temperature

        Returns:
            Tuple of (full_response, extracted_answer, is_valid_format)
        """
        messages = self._format_conversation(problem_description)

        response = self._generate_response(
            messages, max_new_tokens=max_new_tokens, temperature=temperature
        )

        # Pull the final expression out of the response and validate that it
        # parses as arithmetic; callers decide what to do with invalid output.
        extracted_answer = extract_answer(response)
        is_valid = _is_valid_arithmetic_expression(extracted_answer)

        return response, extracted_answer, is_valid
|
|
|