"""
BioRLHF Model Evaluation Example

This script demonstrates how to evaluate a fine-tuned model on
biological reasoning tasks.

Usage:
    python evaluate_model.py --model ./biorlhf_model --test-set kmp_test_set.json
"""
|
|
| import argparse |
| import json |
| from pathlib import Path |
|
|
| from biorlhf import evaluate_model |
|
|
|
|
def _build_parser():
    """Build the command-line parser for the evaluation script.

    Returns:
        An ``argparse.ArgumentParser`` that requires ``--model`` and accepts
        the optional ``--test-set``, ``--base-model``, ``--output``,
        ``--no-quantization``, ``--temperature`` and ``--max-tokens`` flags.
    """
    parser = argparse.ArgumentParser(
        description="Evaluate a fine-tuned BioRLHF model"
    )
    parser.add_argument(
        "--model",
        type=str,
        required=True,
        help="Path to the fine-tuned model directory",
    )
    parser.add_argument(
        "--test-set",
        type=str,
        default="kmp_test_set.json",
        help="Path to test questions JSON file",
    )
    parser.add_argument(
        "--base-model",
        type=str,
        default="mistralai/Mistral-7B-v0.3",
        help="Base model name",
    )
    parser.add_argument(
        "--output",
        type=str,
        default=None,
        help="Output path for detailed results JSON",
    )
    parser.add_argument(
        "--no-quantization",
        action="store_true",
        help="Disable 4-bit quantization",
    )
    parser.add_argument(
        "--temperature",
        type=float,
        default=0.1,
        help="Generation temperature (0 for greedy)",
    )
    parser.add_argument(
        "--max-tokens",
        type=int,
        default=512,
        help="Maximum tokens to generate",
    )
    return parser


def _preview(text, limit):
    """Return *text* cut to *limit* characters, adding "..." only when cut."""
    return f"{text[:limit]}..." if len(text) > limit else text


def _print_report(results):
    """Print the accuracy summary followed by one entry per test question.

    Reads the ``overall_accuracy``, ``correct_answers``, ``total_questions``,
    ``factual_accuracy``, ``reasoning_accuracy``, ``calibration_accuracy`` and
    ``detailed_results`` attributes of *results*. Each detailed entry is
    indexed by the keys ``correct``, ``category``, ``question``, ``expected``
    and ``response``.
    """
    print("\n" + "=" * 60)
    print("EVALUATION RESULTS")
    print("=" * 60)
    print(
        f"\nOverall Accuracy: {results.overall_accuracy:.1%} "
        f"({results.correct_answers}/{results.total_questions})"
    )
    print("\nBy Category:")
    print(f" Factual: {results.factual_accuracy:.1%}")
    print(f" Reasoning: {results.reasoning_accuracy:.1%}")
    print(f" Calibration: {results.calibration_accuracy:.1%}")

    print("\n" + "-" * 60)
    print("Detailed Results:")
    print("-" * 60)

    for index, item in enumerate(results.detailed_results, 1):
        status = "CORRECT" if item["correct"] else "WRONG"
        print(f"\n{index}. [{item['category'].upper()}] {status}")
        # All three fields use the same rule: an ellipsis appears only when
        # the text was actually shortened.
        print(f" Q: {_preview(item['question'], 80)}")
        print(f" Expected: {_preview(item['expected'], 50)}")
        print(f" Response: {_preview(item['response'], 100)}")


def _save_results(args, results):
    """Write the run configuration, metrics and detailed results as JSON.

    The file is written to ``args.output``; missing parent directories are
    created first.
    """
    output_path = Path(args.output)
    # The evaluation has already been paid for at this point, so do not lose
    # the results just because the target directory does not exist yet.
    output_path.parent.mkdir(parents=True, exist_ok=True)

    payload = {
        "model_path": args.model,
        "base_model": args.base_model,
        "test_set": args.test_set,
        "metrics": {
            "overall_accuracy": results.overall_accuracy,
            "factual_accuracy": results.factual_accuracy,
            "reasoning_accuracy": results.reasoning_accuracy,
            "calibration_accuracy": results.calibration_accuracy,
            "total_questions": results.total_questions,
            "correct_answers": results.correct_answers,
        },
        "detailed_results": results.detailed_results,
    }

    with output_path.open("w", encoding="utf-8") as handle:
        json.dump(payload, handle, indent=2)

    print(f"\nDetailed results saved to: {args.output}")


def main():
    """Run model evaluation.

    Parses the command line, checks that the test-set file exists, calls
    ``evaluate_model`` with the parsed options, prints a summary plus
    per-question details, and, when ``--output`` is given, saves everything
    to a JSON file.

    Raises:
        SystemExit: with status 1 (message on stderr) when the test-set path
            is not an existing file; with argparse's own status when the
            command line is invalid.
    """
    parser = _build_parser()
    args = parser.parse_args()

    # Fail before any model loading, and with a non-zero status so shell
    # scripts and CI jobs can detect the problem.
    if not Path(args.test_set).is_file():
        parser.exit(
            status=1,
            message=(
                f"Error: Test set not found at {args.test_set}\n"
                "\nYou can create a test set or use the default one "
                "from the data folder.\n"
            ),
        )

    print("=" * 60)
    print("BioRLHF Model Evaluation")
    print("=" * 60)
    print(f"Model: {args.model}")
    print(f"Base Model: {args.base_model}")
    print(f"Test Set: {args.test_set}")
    print(f"Quantization: {'Disabled' if args.no_quantization else '4-bit'}")
    print("=" * 60)

    results = evaluate_model(
        model_path=args.model,
        test_questions_path=args.test_set,
        base_model=args.base_model,
        use_4bit=not args.no_quantization,
        max_new_tokens=args.max_tokens,
        temperature=args.temperature,
    )

    _print_report(results)

    if args.output:
        _save_results(args, results)

    print("\n" + "=" * 60)
    print("Evaluation complete!")
    print("=" * 60)
|
|
|
|
# Run the evaluation only when this file is executed as a script, so that
# importing it as a module has no side effects.
if __name__ == "__main__":
    main()
|
|