| import json |
| import time |
| import re |
| import os |
| import argparse |
| from datasets import load_dataset |
| from nltk.tokenize import sent_tokenize |
| from utils.util import retriveDoc,compute_best_sentence_f1 |
| from openai import OpenAI |
| import asyncio, json, torch, math |
| from typing import List, Tuple |
| |
| from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline |
| from utils.metrics import qa_f1_score |
| from utils.llmjudge import judge_answer_with_api |
|
|
|
|
# OpenAI-compatible API client; endpoint and key are read from the environment.
client = OpenAI(
    base_url=os.getenv("OPENAI_BASE_URL"),
    api_key=os.getenv("OPENAI_API_KEY"),
)

# Local chat model passed to get_transformers_answer() from main(); placed on the first GPU.
tokenizer1 = AutoTokenizer.from_pretrained(
    "Qwen/Qwen-14B-Chat",
    trust_remote_code=True,
)
model1 = AutoModelForCausalLM.from_pretrained(
    "Qwen/Qwen-14B-Chat",
    trust_remote_code=True,
    device_map="cuda:0",
    torch_dtype=torch.bfloat16,
)

# Second local model, placed on the second GPU and switched to eval mode.
# NOTE(review): nothing in the visible part of this file references
# tok_qwen / model_qwen -- confirm whether they are still needed.
tok_qwen = AutoTokenizer.from_pretrained(
    "Qwen/Qwen2.5-7B-Instruct",
    trust_remote_code=True,
)
model_qwen = AutoModelForCausalLM.from_pretrained(
    "Qwen/Qwen2.5-7B-Instruct",
    trust_remote_code=True,
    device_map="cuda:1",
    torch_dtype=torch.bfloat16,
)
model_qwen = model_qwen.eval()
|
|
def get_transformers_answer(prompt, tokenizer, model, max_new_tokens=100, temperature=0.7, top_p=0.9, retries=3, delay=5):
    """Generate a reply to ``prompt`` with a local transformers model.

    The prompt is wrapped as a single user message and rendered through the
    tokenizer's chat template; if the template cannot be applied the raw
    prompt is used instead. Only the newly generated tokens (everything after
    the input tokens) are decoded and returned.

    Args:
        prompt: User message text.
        tokenizer: Tokenizer providing ``apply_chat_template``, ``__call__``
            and ``decode``.
        model: Model providing ``device`` and ``generate``.
        max_new_tokens: Upper bound on the number of generated tokens.
        temperature: Sampling temperature. A value of ``None`` or ``<= 0``
            selects greedy decoding.
        top_p: Nucleus-sampling threshold, used only when sampling.
        retries: Number of attempts before giving up.
        delay: Seconds to sleep between failed attempts.

    Returns:
        The stripped generated text, or ``None`` if every attempt failed.
    """
    messages = [{"role": "user", "content": prompt}]

    # `temperature` / `top_p` only take effect when sampling is enabled, so
    # request it explicitly instead of relying on the model's default
    # generation config. A non-positive temperature means greedy decoding.
    sampling = temperature is not None and temperature > 0
    generation_kwargs = {"max_new_tokens": max_new_tokens, "do_sample": sampling}
    if sampling:
        generation_kwargs["temperature"] = temperature
        generation_kwargs["top_p"] = top_p

    for attempt in range(retries):
        try:
            try:
                formatted_prompt = tokenizer.apply_chat_template(
                    messages,
                    tokenize=False,
                    add_generation_prompt=True
                )
            except Exception as e:
                # Best effort: tokenizers without a usable chat template
                # still get the plain prompt.
                print(f"Unable to apply chat template: {e}, falling back to basic text input")
                formatted_prompt = prompt

            model_inputs = tokenizer(formatted_prompt, return_tensors="pt").to(model.device)
            generated_ids = model.generate(**model_inputs, **generation_kwargs)

            # The output sequence starts with the input tokens; slice them
            # off at token level so only the continuation is decoded.
            input_length = model_inputs.input_ids.shape[1]
            output_ids = generated_ids[0][input_length:]
            return tokenizer.decode(output_ids, skip_special_tokens=True).strip()
        except Exception as e:
            print(f"Error on attempt {attempt + 1}: {e}")
            if attempt < retries - 1:
                print(f"Retrying in {delay} seconds...")
                time.sleep(delay)
            else:
                print("Max retries reached, skipping this request.")

    return None
|
|
def truncate_answer(answer):
    """Return the text before the first period, or "No answer" for empty input."""
    if not answer:
        return "No answer"
    leading_text, _, _ = answer.partition('.')
    return leading_text.strip()
|
|
def write_to_log(filename, data):
    """Append ``data`` plus a trailing newline to the UTF-8 text file ``filename``."""
    with open(filename, mode='a', encoding='utf-8') as log_file:
        log_file.writelines((data + '\n',))
|
|
def remove_think_tags(text: str) -> str:
    """Drop every closed <think> ... </think> block and trim surrounding whitespace."""
    think_block = re.compile(r'<think>(.*?)</think>', flags=re.DOTALL)
    without_blocks = think_block.sub('', text)
    return without_blocks.strip()
|
|
def build_prompt(context: str, question: str) -> str:
    """Assemble the QA prompt: passages, question, and the required answer format.

    The prompt asks for the answer first ("Answer: ...") and then numbered
    reasoning steps, each followed by a verbatim reference from the passages.
    """
    task_line = "Answer the question based on the given passages."

    # Passages and question, one item per line.
    preamble = "\n".join([
        f"{task_line} The following are the passages:",
        f"{context}",
        task_line,
        f"Question: {question}.",
        "Answer:",
    ]) + "\n"

    # Free-text instructions; each sentence group is followed by one space.
    instructions = " ".join([
        "Please first provide your answer in the format of Answer:[Your answer]. Then provide your reasoning process step-by-step.(Only include explicit clues)",
        "At the end of each reasoning step, include a new line that specifies the key information or reference content used in that step.",
        "Please ensure that the [reference content] you include is the complete original sentence or consecutive sentences from the text. Please do not change the punctuation. Do not use ellipses inside the sentence.",
    ]) + " "

    # Output template the model is asked to follow.
    response_format = "\n".join([
        "Follow this format:",
        "Answer: [Your answer]",
        "Step-by-step Reasoning:",
        "1. [Reasoning step 1]",
        "[replaced by your reference content]",
        "2. [Reasoning step 2]",
        "[replaced by your reference content]",
    ]) + "\n"

    return preamble + instructions + response_format
|
|
def extract_final_bullet_passage(answer_text: str):
    """Return ``(passage_number, snippet)`` for the last cited passage.

    Looks at the numbered steps after "Step-by-step Reasoning:", walking them
    from last to first, and returns the final "Passage N: ..." citation of the
    first step that contains one. Returns ``(None, None)`` when there is no
    reasoning section, no numbered step, or no citation.
    """
    not_found = (None, None)

    section = re.search(r"Step-by-step Reasoning:\s*(.*)", answer_text, flags=re.DOTALL)
    if section is None:
        return not_found

    body = section.group(1).strip()
    steps = re.findall(r"(?m)^(\d+\.\s.*?)(?=(?:\n\d+\.\s)|\Z)", body, flags=re.DOTALL)
    if not steps:
        print("No bullet blocks found.")
        return not_found

    # Groups: 1 = passage number, 2 = whole snippet, 3 = text inside double
    # quotes, 4 = unquoted text.
    citation_re = re.compile(
        r'(?i)(?:\*\*)?passage\s+(\d+)(?:\*\*)?\s*:\s*("([^"]*)"|(.+?))(?=\Z|\n\s*\n|$)',
        flags=re.DOTALL
    )

    for step in steps[::-1]:
        citations = citation_re.findall(step)
        if not citations:
            continue
        passage_number, _whole, quoted, unquoted = citations[-1]
        unquoted = unquoted.strip()
        # Prefer the unquoted form; fall back to the quoted text.
        snippet = unquoted if unquoted else quoted.strip()
        return passage_number, snippet

    return not_found
|
|
def extract_all_bullet_passages(answer_text: str):
    """Split the "Step-by-step Reasoning:" section into its numbered steps.

    Returns a list of ``{'bullet_index': n, 'snippet': text}`` dicts, numbered
    from 1 in order of appearance, or an empty list when the section or the
    numbered steps are missing. The result is also printed.
    """
    section = re.search(r"Step-by-step Reasoning:\s*(.*)", answer_text, flags=re.DOTALL)
    if section is None:
        return []

    body = section.group(1).strip()
    # A step runs from a line starting with "<digits>. " up to the next such
    # line (or the end of the text).
    steps = re.findall(r"^(\d+\.\s.*?)(?=^\d+\.\s|\Z)", body, flags=re.MULTILINE | re.DOTALL)
    if not steps:
        return []

    results = [
        {'bullet_index': position, 'snippet': step.strip()}
        for position, step in enumerate(steps, start=1)
    ]
    print(results)
    return results
|
|
def extract_evidence(answer_text: str):
    """Collect the numbered items that follow the first "Evidence" marker.

    The marker is matched case-insensitively. Items before the first one that
    starts with "1." are discarded; the remaining items are re-indexed from 1.
    Returns a list of ``{'bullet_index': n, 'snippet': text}`` dicts, or an
    empty list when the marker, the items, or an item "1." is missing.
    """
    section = re.search(r"(?i)Evidence\s*(.*)", answer_text, flags=re.DOTALL)
    if section is None:
        return []

    body = section.group(1).strip()
    items = re.findall(r"^(\d+\.\s.*?)(?=^\d+\.\s|\Z)", body, flags=re.MULTILINE | re.DOTALL)
    if not items:
        return []

    # Position of the first item numbered "1."; -1 when there is none.
    first_item = next(
        (pos for pos, item in enumerate(items) if item.strip().startswith("1.")),
        -1,
    )
    if first_item == -1:
        return []

    return [
        {'bullet_index': position, 'snippet': item.strip()}
        for position, item in enumerate(items[first_item:], start=1)
    ]
|
|
|
|
def get_answer_with_retry(model, prompt, retries=3, delay=5):
    """Send ``prompt`` as a single user message to ``model`` via the module-level client.

    Any exception triggers a retry after ``delay`` seconds, up to ``retries``
    attempts in total. Returns the stripped reply text, or ``None`` when all
    attempts fail.
    """
    attempts_made = 0
    while attempts_made < retries:
        attempts_made += 1
        try:
            response = client.chat.completions.create(
                model=model,
                messages=[{'role': 'user', 'content': prompt}]
            )
            reply = response.choices[0].message.content
            return reply.strip()
        except Exception as e:
            print(f"Error on attempt {attempts_made}: {e}")
            if attempts_made >= retries:
                print("Max retries reached, skipping this request.")
                return None
            print(f"Retrying in {delay} seconds...")
            time.sleep(delay)
    return None
|
|
| def extract_json_from_gpt_response(text: str) -> dict | None: |
| """ |
| Finds the first JSON block inside ```json ... ``` or ``` … ``` and returns it as a dict. |
| """ |
| |
| m = re.search(r"```json\s*(\{.*?\})\s*```", text, flags=re.DOTALL) |
| if not m: |
| |
| m = re.search(r"```(?:json)?\s*(\{.*?\})\s*```", text, flags=re.DOTALL) |
| if not m: |
| |
| m = re.search(r"(\{.*?\})", text, flags=re.DOTALL) |
| if not m: |
| return None |
|
|
| json_str = m.group(1) |
| try: |
| return json.loads(json_str) |
| except json.JSONDecodeError: |
| |
| cleaned = re.sub(r",\s*([\]}])", r"\1", json_str) |
| try: |
| return json.loads(cleaned) |
| except json.JSONDecodeError: |
| return None |
|
|
async def random_alternative_answer(
    question: str,
    original_context: str,
    unique_sents: List[str],
    correct_answer: str
) -> dict:
    """Ask gpt-4o for a different plausible answer and rewrite the evidence to match.

    The model is asked to return JSON with an ``"answer"`` string and a
    ``"revised"`` list holding one rewritten sentence per entry of
    ``unique_sents``. Each original sentence is then replaced by its rewrite
    inside ``original_context``.

    Args:
        question: The question being asked.
        original_context: Full context in which the sentences are replaced.
        unique_sents: Evidence sentences to rewrite, in order.
        correct_answer: The original answer the alternative must differ from.

    Returns:
        ``{"context": <rewritten context>, "answer": <alternative answer>}``.
        If the response cannot be parsed into the expected shape, the original
        context is returned with the answer "Failed to generate alternative".

    Note:
        The API call is synchronous, so this coroutine blocks while waiting
        for the response. Exceptions raised by the client propagate.
    """
    failure = {"context": original_context, "answer": "Failed to generate alternative"}

    numbered = "\n\n".join(f"{j+1}. {s}" for j, s in enumerate(unique_sents))
    prompt = (
        "You are a creative assistant. Given the question below and the original answer, propose a plausible alternative answer that is **different** from the original but still reasonable. "
        "Then rewrite the provided sentences to support your alternative answer. When rewriting each sentence, modify only the parts necessary to support the alternative answer. "
        "Parts unrelated to the answer must keep their original meaning. Be sure that the modified evidence sentences are sufficient to answer the original question. "
        "Output must be strictly in the specified JSON format, with no additional text.\n"
        '{\n'
        ' "answer": "<your alternative answer here, just provide the answer phrase, no need for complete sentence>",\n'
        ' "revised": [\n'
        ' "<rewritten sentence 1>",\n'
        ' "<rewritten sentence 2>",\n'
        ' ...\n'
        ' ]\n'
        '}\n\n'
        f"Question:\n{question}\n\n"
        f"Original answer:\n{correct_answer}\n\n"
        f"Sentences to rewrite:\n{numbered}"
    )

    print(f"[Alternative Answer] Generating prompt: {prompt}")

    rsp = client.chat.completions.create(
        model="gpt-4o", temperature=0.7,
        messages=[{"role":"user","content":prompt}]
    )

    # The message content can be empty; treat that like unparseable JSON.
    content = rsp.choices[0].message.content
    js = extract_json_from_gpt_response(content) if content else None
    if not js or not isinstance(js, dict):
        print("[Alternative Answer] Failed to parse JSON")
        return failure

    # Validate the shape instead of indexing blindly: a response missing a
    # key or using the wrong types must not crash the whole run.
    revised = js.get("revised")
    alternative = js.get("answer")
    if not isinstance(revised, list) or not isinstance(alternative, str) or not alternative.strip():
        print("[Alternative Answer] JSON is missing a valid 'answer' or 'revised' field")
        return failure

    new_ctx = original_context
    for old, new in zip(unique_sents, revised):
        # str.replace("", new) would insert `new` between every character of
        # the context, so empty originals and non-string rewrites are skipped.
        if not old or not isinstance(new, str):
            continue
        new_ctx = new_ctx.replace(old, new)

    return {"context": new_ctx, "answer": alternative}
|
|
def main():
    """Run alternative-answer generation over a slice of a QA dataset.

    For each selected example:
      1. Ask 'deepseek-r1' to answer with step-by-step reasoning and verbatim
         references; skip the example if no answer is returned or the judge
         rejects the answer.
      2. Retrieve the context chunks matching each reasoning step.
      3. Check with the local model whether those chunks alone are enough to
         answer; if not, add the chunks retrieved for the question itself.
      4. Generate an alternative answer with rewritten evidence and append
         ``{"question", "answer", "context"}`` as one JSON line to the output.
    """
    parser = argparse.ArgumentParser(description="LastingBench random alternative answer generation")
    parser.add_argument("--output", "-o", type=str, default="output_random.jsonl",
                        help="Output JSONL file path (default: output_random.jsonl)")
    parser.add_argument("--dataset_repo", type=str, default="THUDM/LongBench",
                        help="Dataset repository name (default: THUDM/LongBench)")
    parser.add_argument("--dataset_subset", type=str, default="hotpotqa",
                        help="Dataset subset name (default: hotpotqa)")
    parser.add_argument("--split", type=str, default="test",
                        help="Dataset split (default: test)")
    parser.add_argument("--start_idx", type=int, default=0,
                        help="Starting index for processing (default: 0)")
    parser.add_argument("--max_samples", type=int, default=-1,
                        help="Maximum number of samples to process (-1 for all, default: -1)")

    args = parser.parse_args()

    out_file = args.output

    longbench = load_dataset(args.dataset_repo, args.dataset_subset)[args.split]

    print(f"Output file: {out_file}")
    print(f"Dataset: {args.dataset_repo}/{args.dataset_subset}[{args.split}]")
    print(f"Total samples: {len(longbench)}")

    start_idx = args.start_idx
    end_idx = len(longbench) if args.max_samples == -1 else min(start_idx + args.max_samples, len(longbench))

    print(f"Processing samples from index {start_idx} to {end_idx-1}")

    for idx in range(start_idx, end_idx):
        example = longbench[idx]
        question = example['input']
        print(f"Question: {question}")
        context = example['context']
        correct_answer = example['answers'][0]

        print(f"Processing example {idx + 1}:")
        print(f"Correct Answer: {correct_answer}")

        prompt_with_context = build_prompt(context, question)

        raw_answer = get_answer_with_retry('deepseek-r1', prompt_with_context)
        if raw_answer is None:
            # All retries failed; there is nothing to judge or extract.
            print("No answer returned by the model, skipping this example.")
            continue

        # Strip <think> blocks before parsing so that an "Answer:" occurring
        # inside the model's thinking is not mistaken for the final answer.
        answer_with_context = remove_think_tags(raw_answer)

        answer_with_context_simple = (
            answer_with_context
            .split("Answer:", 1)[-1]
            .split("Step-by-step Reasoning", 1)[0]
            .strip()
        )

        print(f"Answer with context: {answer_with_context_simple}")
        judged_correct = judge_answer_with_api(question, correct_answer, answer_with_context_simple)
        print(f"Answer judge result: {judged_correct}")

        if not judged_correct:
            continue

        evidence = extract_all_bullet_passages(answer_with_context)

        # NOTE(review): the nesting from here on was reconstructed from
        # de-indented source -- confirm against the original layout.
        page_contents = []
        for ev in evidence:
            retrieved_docs = retriveDoc(context, ev['snippet'])
            page_contents += [doc.page_content for doc in retrieved_docs]

        # dict.fromkeys de-duplicates while keeping first-seen order.
        unique_page_contents = list(dict.fromkeys(page_contents))
        aggregated_content = "\n".join(unique_page_contents)

        prompt_final = (
            f"Please answer the question based on the context.\nContext: {aggregated_content}.\n Question: {question}.\n"
            f"Please only provide your answer. "
            f"Your Answer:"
        )

        final_answer = get_transformers_answer(prompt_final, tokenizer1, model1)

        # A failed local generation (None) counts as "evidence not sufficient".
        evidence_sufficient = final_answer is not None and judge_answer_with_api(
            question, correct_answer, final_answer
        )
        if evidence_sufficient:
            print("correct")
        else:
            print("incorrect")
            # Widen the evidence with chunks retrieved for the question itself.
            result_query = retriveDoc(context, question)
            page_contents += [doc.page_content for doc in result_query]
            unique_page_contents = list(dict.fromkeys(page_contents))

        alternative = asyncio.run(
            random_alternative_answer(
                question,
                context,
                unique_page_contents,
                correct_answer
            )
        )

        record = {
            "question": question,
            "answer": alternative["answer"],
            "context": alternative["context"]
        }

        # Append per example so that partial progress survives a crash.
        with open(out_file, "a", encoding="utf-8") as fout:
            fout.write(json.dumps(record, ensure_ascii=False) + "\n")
|
|
| if __name__ == "__main__": |
| main() |