"""
Evaluate models on a LongBench subset with Exact-Match (EM).
Supports both Qwen3 (Transformers) and other models (vLLM).

Requirements
------------
pip install vllm datasets tqdm transformers accelerate
"""

import argparse
import logging
import os
import time
from pathlib import Path

from datasets import load_dataset
from tqdm import tqdm

from utils.metrics import qa_em_score
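# qa_em_score (local helper) is assumed to return a truthy value when the
# prediction is a normalized exact match for the gold answer.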


parser = argparse.ArgumentParser()
parser.add_argument("--hf_model",
                    default="Qwen/Qwen3-8B",
                    help="Model name or local path")
parser.add_argument("--is_qwen3", action="store_true",
                    help="Set this flag for Qwen3 models (uses Transformers); "
                         "otherwise vLLM is used.")
parser.add_argument("--max_new_tokens", type=int, default=20,
                    help="For Transformers generation (used with --is_qwen3)")
parser.add_argument("--max_tokens", type=int, default=20,
                    help="For vLLM sampling (ignored if --is_qwen3)")
parser.add_argument("--temperature", type=float, default=0.0)
parser.add_argument("--top_p", type=float, default=1.0)
parser.add_argument("--tensor_parallel_size", type=int, default=2,
                    help="GPU parallel size for vLLM (ignored if --is_qwen3)")

parser.add_argument("--dataset_repo", default="THUDM/LongBench")
parser.add_argument("--dataset_subset", default="hotpotqa")
parser.add_argument("--split", default="test")
parser.add_argument("--sleep", type=float, default=0.0,
                    help="Seconds to pause between samples (0 = no pause)")
parser.add_argument("--log", default="summary.log")
parser.add_argument("--cuda_devices", default="1,6",
                    help="CUDA visible devices")
args = parser.parse_args()


# Must be visible before any CUDA context is created; torch is therefore
# imported lazily inside the model branches below to keep the ordering explicit.
os.environ["CUDA_VISIBLE_DEVICES"] = args.cuda_devices


logging.basicConfig(
    filename=args.log,
    level=logging.INFO,
    format="%(asctime)s - %(message)s",
    filemode="a",
)
logging.getLogger().addHandler(logging.StreamHandler())  # echo to stdout too


ds = load_dataset(args.dataset_repo, args.dataset_subset, split=args.split)
total = len(ds)
logging.info("Loaded %d samples from %s/%s[%s]",
             total, args.dataset_repo, args.dataset_subset, args.split)
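# Each example is expected to hold the question under "input" and a list of
# gold answers under "answers" (LongBench schema).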


if args.is_qwen3:
    import torch  # imported only after CUDA_VISIBLE_DEVICES is set
    from transformers import AutoTokenizer, AutoModelForCausalLM

    load_kwargs = dict(
        trust_remote_code=True,
        device_map="auto",
    )

    tokenizer = AutoTokenizer.from_pretrained(args.hf_model,
                                              trust_remote_code=True)
    model = AutoModelForCausalLM.from_pretrained(
        args.hf_model,
        torch_dtype=torch.float16,
        **load_kwargs,
    )

    EOS_ID = tokenizer.eos_token_id
    THINK_ENDID = 151668  # id of the "</think>" token in the Qwen3 vocabulary

    gen_kwargs = dict(
        max_new_tokens=args.max_new_tokens,
        temperature=args.temperature,
        top_p=args.top_p,
        do_sample=args.temperature > 0,  # greedy when temperature == 0
        eos_token_id=EOS_ID,
    )


    correct_em = 0

    for ex in tqdm(ds, desc="Evaluating with Transformers (Qwen3)"):
        q = ex["input"]
        golds = ex["answers"]

        msgs = [
            {"role": "system", "content": "You are a QA assistant."},
            {"role": "user",
             "content": f"Question: {q}\n"
                        "Please reply with *only* the final answer—no extra words."},
        ]
        prompt = tokenizer.apply_chat_template(
            msgs,
            tokenize=False,
            add_generation_prompt=True,
            enable_thinking=False,  # disable Qwen3's thinking mode
        )
        inputs = tokenizer([prompt], return_tensors="pt").to(model.device)

        with torch.no_grad():
            outs = model.generate(**inputs, **gen_kwargs)[0]

        # Keep only the newly generated tokens (drop the prompt).
        new_ids = outs[len(inputs.input_ids[0]):].tolist()

        # If a "</think>" token slipped through, keep only the text after the
        # last one; with enable_thinking=False none is expected and idx stays 0.
        try:
            idx = len(new_ids) - new_ids[::-1].index(THINK_ENDID)
        except ValueError:
            idx = 0

        content = tokenizer.decode(new_ids[idx:],
                                   skip_special_tokens=True).strip()

        if any(qa_em_score(content, g) for g in golds):
            correct_em += 1

        if args.sleep:
            time.sleep(args.sleep)


else:
    from vllm import LLM, SamplingParams

    llm = LLM(
        model=args.hf_model,
        tensor_parallel_size=args.tensor_parallel_size,
    )
    # One SamplingParams object, reused for every sample (previously a second,
    # slightly different copy was rebuilt inside the loop and this one unused).
    sampler = SamplingParams(
        temperature=args.temperature,
        max_tokens=args.max_tokens,
        top_p=args.top_p,
        stop=["</s>", "<|end_of_text|>"],
    )
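    # Assumption: llm.chat (which applies the model's own chat template) is
    # available; it requires a reasonably recent vLLM release.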


    correct_em = 0

    for ex in tqdm(ds, desc="Evaluating with vLLM"):
        question = ex["input"]
        golds = ex["answers"]

        messages = [
            {"role": "system",
             "content": "You are a QA assistant."},
            {"role": "user",
             "content": f"Question: {question}\n"
                        "Please reply with *only* the final answer—no extra words.\nAnswer:"},
        ]

        result = llm.chat(messages, sampling_params=sampler)
        pred = result[0].outputs[0].text.strip()
        print(f"A: {pred}\nG: {golds}\n")

        if any(qa_em_score(pred, g) for g in golds):
            correct_em += 1

        if args.sleep:
            time.sleep(args.sleep)


em = correct_em / total
model_type = "Qwen3 (Transformers)" if args.is_qwen3 else "vLLM"
logging.info("RESULT | model=%s | type=%s | subset=%s | EM=%.4f",
             args.hf_model, model_type, args.dataset_subset, em)
print(
    f"\n=== SUMMARY ===\n"
    f"Model  : {args.hf_model}\n"
    f"Type   : {model_type}\n"
    f"Subset : {args.dataset_subset} ({args.split})\n"
    f"EM     : {em:.4f}\n"
    f"(Log in {Path(args.log).resolve()})"
)