|
|
| """
|
| Run inference using SFT local model + GRPO model from HuggingFace.
|
| """
|
|
|
| import os
|
| import sys
|
| import logging
|
| from pathlib import Path
|
| import torch
|
|
|
|
|
|
|
|
|
# Resolve the repository root (two levels above this file) and put it on
# sys.path so absolute `src.*` imports work when this script is run
# directly rather than as an installed package.
PROJECT_ROOT = Path(__file__).resolve().parents[2]

sys.path.append(str(PROJECT_ROOT))

# Basic console logging: "timestamp - level - message".
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s - %(levelname)s - %(message)s"
)
logger = logging.getLogger("model_inference")

# Import the inference wrapper only AFTER the sys.path patch above, and
# fail fast with a readable diagnostic if the project layout is wrong.
# NOTE(review): the "β" character below looks like a mis-encoded emoji
# from the original file; preserved byte-for-byte since it is runtime text.
try:
    from src.utils.inference import GRPOModelInference
except Exception as e:
    print("β Import src.utils.inference FAILED!")
    print("PROJECT_ROOT =", PROJECT_ROOT)
    print("Error =", e)
    sys.exit(1)
|
|
|
|
|
def main():
    """Load the SFT + GRPO countdown model and run one demo inference.

    Builds a ``GRPOModelInference`` wrapper (local SFT checkpoint plus a
    GRPO model pulled from the HuggingFace Hub), asks it to solve one
    fixed countdown-style arithmetic problem, then prints the raw model
    output, the extracted answer expression, and whether the output
    matched the expected answer format.

    Returns ``None``; on model-load failure it logs the traceback,
    prints a console message, and returns early instead of raising.
    """
    # NOTE(review): the "π"/"β" characters in the console strings look
    # like mis-encoded emoji from the original file; kept as-is because
    # they are runtime output, except where noted below.
    print("π Starting inference...")

    try:
        # Paths / model ids are hard-coded for this demo script.
        model_inference = GRPOModelInference(
            sft_model_path="models/sft/checkpoint-100",
            grpo_model_path="Dat1710/countdown-grpo-qwen2",
            base_model_id="Qwen/Qwen2.5-Math-1.5B",
            device="auto",
            # fp16 halves memory vs fp32 — fine for inference-only use.
            dtype=torch.float16,
        )
        # BUGFIX: this literal was split across two physical lines in the
        # original source (a SyntaxError); rejoined onto one line.
        print("β Models loaded successfully!")
    except Exception as e:
        # Record the full traceback via the module logger (previously the
        # error was only printed), but keep the original console output.
        logger.exception("Failed to load models")
        print("β Failed to load models!")
        print(e)
        return

    # Fixed demo problem: reach the target using each number exactly once.
    problem = (
        "Your task: Use 53, 3, 47, and 36 exactly once each with "
        "only +, -, *, and / operators to create an expression equal to 133."
    )

    print("\nπ Problem:")
    print(problem)

    # solve_problem returns (raw model text, parsed answer, format flag).
    response, extracted_answer, is_valid = model_inference.solve_problem(
        problem_description=problem,
        max_new_tokens=512,
        temperature=0.8,
    )

    print("\n================= MODEL OUTPUT =================")
    print(response)
    print("================================================\n")

    print("π Extracted Answer:", extracted_answer)
    print("π Valid format:", is_valid)
|
|
|
|
|
# Standard script entry-point guard: run the demo only when executed
# directly, not when this module is imported.
if __name__ == "__main__":
    main()
|
|
|