#!/usr/bin/env python3
"""
BioRLHF Model Evaluation Example
This script demonstrates how to evaluate a fine-tuned model on
biological reasoning tasks.
Usage:
python evaluate_model.py --model ./biorlhf_model --test-set kmp_test_set.json
"""
import argparse
import json
from pathlib import Path
from biorlhf import evaluate_model


def main():
    """Run model evaluation."""
    parser = argparse.ArgumentParser(
        description="Evaluate a fine-tuned BioRLHF model"
    )
    parser.add_argument(
        "--model",
        type=str,
        required=True,
        help="Path to the fine-tuned model directory",
    )
    parser.add_argument(
        "--test-set",
        type=str,
        default="kmp_test_set.json",
        help="Path to test questions JSON file",
    )
    parser.add_argument(
        "--base-model",
        type=str,
        default="mistralai/Mistral-7B-v0.3",
        help="Base model name",
    )
    parser.add_argument(
        "--output",
        type=str,
        default=None,
        help="Output path for detailed results JSON",
    )
    parser.add_argument(
        "--no-quantization",
        action="store_true",
        help="Disable 4-bit quantization",
    )
    parser.add_argument(
        "--temperature",
        type=float,
        default=0.1,
        help="Generation temperature (0 for greedy)",
    )
    parser.add_argument(
        "--max-tokens",
        type=int,
        default=512,
        help="Maximum tokens to generate",
    )
    args = parser.parse_args()

    # Check if test set exists
    if not Path(args.test_set).exists():
        print(f"Error: Test set not found at {args.test_set}")
        print("\nYou can create a test set or use the default one from the data folder.")
        return
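    # NOTE: the exact test-set schema is defined by the biorlhf package, not by
    # this script. Judging from the fields surfaced in results.detailed_results
    # below (question / expected / category), a plausible entry might look like:
    #   {"question": "...", "expected": "...", "category": "factual"}
    # Treat this as an illustrative assumption, not a documented format.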
print("=" * 60)
print("BioRLHF Model Evaluation")
print("=" * 60)
print(f"Model: {args.model}")
print(f"Base Model: {args.base_model}")
print(f"Test Set: {args.test_set}")
print(f"Quantization: {'Disabled' if args.no_quantization else '4-bit'}")
print("=" * 60)

    # Run evaluation
    results = evaluate_model(
        model_path=args.model,
        test_questions_path=args.test_set,
        base_model=args.base_model,
        use_4bit=not args.no_quantization,
        max_new_tokens=args.max_tokens,
        temperature=args.temperature,
    )
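    # `results` is expected to expose overall and per-category accuracies plus a
    # per-question `detailed_results` list; see the attribute accesses below.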

    # Print results
    print("\n" + "=" * 60)
    print("EVALUATION RESULTS")
    print("=" * 60)
    print(f"\nOverall Accuracy: {results.overall_accuracy:.1%} ({results.correct_answers}/{results.total_questions})")
    print("\nBy Category:")
    print(f"  Factual: {results.factual_accuracy:.1%}")
    print(f"  Reasoning: {results.reasoning_accuracy:.1%}")
    print(f"  Calibration: {results.calibration_accuracy:.1%}")

    # Show detailed results
    print("\n" + "-" * 60)
    print("Detailed Results:")
    print("-" * 60)
    for i, r in enumerate(results.detailed_results, 1):
        status = "CORRECT" if r["correct"] else "WRONG"
        print(f"\n{i}. [{r['category'].upper()}] {status}")
        print(f" Q: {r['question'][:80]}...")
        expected = r["expected"] if len(r["expected"]) <= 50 else r["expected"][:50] + "..."
        response = r["response"] if len(r["response"]) <= 100 else r["response"][:100] + "..."
        print(f" Expected: {expected}")
        print(f" Response: {response}")

    # Save detailed results if requested
    if args.output:
        output_data = {
            "model_path": args.model,
            "base_model": args.base_model,
            "test_set": args.test_set,
            "metrics": {
                "overall_accuracy": results.overall_accuracy,
                "factual_accuracy": results.factual_accuracy,
                "reasoning_accuracy": results.reasoning_accuracy,
                "calibration_accuracy": results.calibration_accuracy,
                "total_questions": results.total_questions,
                "correct_answers": results.correct_answers,
            },
            "detailed_results": results.detailed_results,
        }
        with open(args.output, "w") as f:
            json.dump(output_data, f, indent=2)
        print(f"\nDetailed results saved to: {args.output}")
print("\n" + "=" * 60)
print("Evaluation complete!")
print("=" * 60)
if __name__ == "__main__":
main()