|
|
| """
|
| generate_prompts_v8_batch_fixed.py
|
|
|
| - Uses batch retrieval for Context, QA, and Relationships
|
| - Saves in batches with checkpointing
|
| - Pads contexts and QA to fixed sizes
|
| - Appends metadata at the end
|
| """
|
|
|
| import os, json, torch, numpy as np
|
| from pathlib import Path
|
| from tqdm import tqdm
|
| from sentence_transformers import SentenceTransformer
|
| from concurrent.futures import ThreadPoolExecutor
|
|
|
| from context_retreiver import retriever as context_retriever
|
| from qa_retreiver import search_topk as qa_retreiver
|
| from relationships_retreiver import batch_relationships
|
|
|
# --- I/O and batching configuration -----------------------------------------

QA_FILE = Path("got_all_qa_final.json")  # input dataset: JSON list of QA items

OUT_DIR = Path("prompts_out")  # all generated prompt batches are written here

CHECKPOINT_FILE = OUT_DIR / "checkpoint.json"  # holds {"next_index": int} for resuming

SAVE_BATCH_SIZE = 512  # prompts per output JSON file

EMBED_BATCH_SIZE = 32  # questions retrieved/processed per loop iteration

# Prefer the first CUDA GPU when available, otherwise fall back to CPU.
DEVICE = "cuda:0" if torch.cuda.is_available() else "cpu"

print(f"[INFO] Using device: {DEVICE}")

# NOTE(review): EMBED_MODEL is never referenced anywhere in this file —
# presumably the imported retrievers load their own embedders; confirm whether
# this (slow, memory-heavy) model load is actually needed here.
EMBED_MODEL = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2", device=DEVICE)
|
|
|
# Structural delimiter tokens that mark the sections of each generated prompt.
# The same list is dumped to special_tokens_used.txt at the end of a run so a
# tokenizer can register them as special tokens.
STRUCTURAL_TOKENS = [

    "<|CTX_QA|>", "<|/CTX_QA|>",

    "<|CTX_REL|>", "<|/CTX_REL|>",

    "<|INSTR|>", "<|/INSTR|>",

    "<|QUESTION|>", "<|/QUESTION|>",

    "<|ANSWER|>", "<|/ANSWER|>",

    # Five fixed slots for retrieved similar QA pairs (padded when fewer found).
    "<|QA_SIM_1|>", "<|/QA_SIM_1|>",

    "<|QA_SIM_2|>", "<|/QA_SIM_2|>",

    "<|QA_SIM_3|>", "<|/QA_SIM_3|>",

    "<|QA_SIM_4|>", "<|/QA_SIM_4|>",

    "<|QA_SIM_5|>", "<|/QA_SIM_5|>"

]
|
|
|
def read_checkpoint():
    """Return the next dataset index to process, or 0 if no usable checkpoint.

    Reads CHECKPOINT_FILE, expected to contain JSON like {"next_index": int}.
    Any read/parse problem is treated as "start from scratch".

    Fix: the original used a bare ``except:``, which also swallows
    KeyboardInterrupt/SystemExit; catch only the concrete failure modes.
    """
    if not CHECKPOINT_FILE.exists():
        return 0
    try:
        return int(json.loads(CHECKPOINT_FILE.read_text())["next_index"])
    except (OSError, json.JSONDecodeError, KeyError, TypeError, ValueError):
        # Missing/corrupt/partial checkpoint file: restart from the beginning.
        return 0
|
|
|
def write_checkpoint(idx):
    """Persist *idx* as the next dataset index to resume from."""
    OUT_DIR.mkdir(parents=True, exist_ok=True)
    payload = json.dumps({"next_index": idx})
    CHECKPOINT_FILE.write_text(payload)
|
|
|
def metadata_to_str(meta):
    """Render a metadata dict as 'k=v; k=v; ...', keeping only scalar values.

    Non-scalar values (lists, dicts, None, ...) are silently dropped; a falsy
    *meta* yields the empty string.
    """
    if not meta:
        return ""
    scalar_types = (str, int, float, bool)
    pairs = [
        f"{key}={value}"
        for key, value in meta.items()
        if isinstance(value, scalar_types)
    ]
    return "; ".join(pairs)
|
|
|
def append_metadata_at_end(answer, context1_text, context1_meta):
    """Compose the gold target string: answer, then [Context1: ...], then (meta: ...).

    Every piece is optional; present pieces are joined with single spaces.
    Only scalar metadata values (str/int/float/bool) are rendered.
    """
    pieces = []
    if answer:
        pieces.append(answer.strip())
    if context1_text:
        pieces.append(f"[Context1: {context1_text.strip()}]")
    if context1_meta:
        rendered = "; ".join(
            f"{key}={value}"
            for key, value in context1_meta.items()
            if isinstance(value, (str, int, float, bool))
        )
        if rendered:
            pieces.append(f"(meta: {rendered})")
    return " ".join(pieces)
|
|
|
def build_prompt(ctx_texts, rel_text, sim_qas, question):
    """Assemble a fixed-layout prompt from contexts, relations, and similar QAs.

    Empty context strings are dropped; exactly five <|QA_SIM_k|> slots are
    always emitted, padded with empty slots when fewer than five QA pairs are
    supplied. Sections are separated by blank lines and the prompt ends with
    an open <|ANSWER|> tag for the model to complete.
    """
    sections = [f"<|CTX_QA|> {ctx} <|/CTX_QA|>" for ctx in ctx_texts if ctx]
    if rel_text:
        sections.append(f"<|CTX_REL|> {rel_text} <|/CTX_REL|>")
    filled = min(len(sim_qas), 5)
    for slot in range(1, filled + 1):
        qa = sim_qas[slot - 1]
        sections.append(
            f"<|QA_SIM_{slot}|> Q: {qa['question']} A: {qa['answer']} <|/QA_SIM_{slot}|>"
        )
    for slot in range(filled + 1, 6):
        sections.append(f"<|QA_SIM_{slot}|> <|/QA_SIM_{slot}|>")
    sections.append("<|INSTR|> Use above contexts to answer concisely. <|/INSTR|>")
    sections.append(f"<|QUESTION|> {question} <|/QUESTION|>")
    sections.append("<|ANSWER|>")
    return "\n\n".join(sections)
|
|
|
def retrieve_contexts(questions, top_k=3):
    """Batch-retrieve context texts and metadata, padded to exactly *top_k*.

    Returns a list (one entry per question) of ``(texts, metas)`` tuples.
    Missing hits are padded with "" / {} so callers can unpack a fixed size.
    """
    batch_res = context_retriever.batch_retrieve(questions, top_k=top_k)
    padded = []
    for results in batch_res:
        hits = results[:top_k]
        texts = [hit["text"] for hit in hits]
        metas = [hit["metadata"] for hit in hits]
        shortfall = top_k - len(hits)
        texts.extend([""] * shortfall)
        metas.extend({} for _ in range(shortfall))
        padded.append((texts, metas))
    return padded
|
|
|
def retrieve_qas_and_rels(questions, max_workers=20):
    """Fetch top-5 similar QAs and the best relationship text per question.

    Both retrievals fan out over a thread pool (I/O-bound lookups).
    Returns two parallel lists aligned with *questions*.
    """
    with ThreadPoolExecutor(max_workers=max_workers) as pool:
        qa_results = list(pool.map(lambda q: qa_retreiver([q], k=5), questions))
        rel_results = list(
            pool.map(lambda q: batch_relationships([q], top_k=1)[0], questions)
        )
    return qa_results, rel_results
|
|
|
def main():
    """Generate structured prompts for every QA item, saving in resumable batches.

    For each EMBED_BATCH_SIZE-sized slice of the dataset:
      1. batch-retrieve top-3 contexts per question,
      2. thread-retrieve similar QAs and relationship text,
      3. build a fixed-layout prompt; context 1 is held out and appended
         (with its metadata) to the gold answer instead.

    Files of up to SAVE_BATCH_SIZE prompts are written to OUT_DIR.

    Fixes vs. original:
      - all writes now pass encoding='utf-8' (the final batch and the token
        list were written with the locale default, which can fail or corrupt
        non-ASCII text since dumps uses ensure_ascii=False);
      - the checkpoint is only advanced AFTER prompts are flushed to disk, so
        a crash can no longer skip generated-but-unsaved prompts on resume.
    """
    OUT_DIR.mkdir(parents=True, exist_ok=True)
    with open(QA_FILE, 'r', encoding='utf-8') as f:
        qas = json.load(f)
    total = len(qas)
    start_idx = read_checkpoint()
    if start_idx >= total:
        print("[INFO] Checkpoint beyond dataset length.")
        return

    prompts_accum = []
    batch_count = start_idx // SAVE_BATCH_SIZE

    def _flush(next_index):
        # Write the accumulated prompts, then (and only then) advance the
        # checkpoint — ordering guarantees no data loss on crash/resume.
        nonlocal batch_count, prompts_accum
        out_path = OUT_DIR / f"prompts_batch_{batch_count:03d}.json"
        out_path.write_text(
            json.dumps(prompts_accum, ensure_ascii=False, indent=2),
            encoding='utf-8',
        )
        batch_count += 1
        prompts_accum = []
        write_checkpoint(next_index)

    for batch_start in tqdm(range(start_idx, total, EMBED_BATCH_SIZE)):
        batch_end = min(batch_start + EMBED_BATCH_SIZE, total)
        batch_items = qas[batch_start:batch_end]
        # Tolerate several key spellings in the source dataset.
        questions = [it.get("question") or it.get("q") or it.get("Question") for it in batch_items]
        orig_answers = [it.get("answer") or it.get("a") or it.get("Answer", "") for it in batch_items]

        contexts = retrieve_contexts(questions, top_k=3)
        sim_qas_list, rel_list = retrieve_qas_and_rels(questions)

        for i, q in enumerate(questions):
            if not q:
                continue  # no recognizable question field on this item
            ctx_texts, ctx_metas = contexts[i]
            context1, context2, context3 = ctx_texts
            # Contexts 2 and 3 go into the prompt; context 1 (plus its
            # metadata) is appended to the gold answer.
            prompt_text = build_prompt([context2, context3], rel_list[i], sim_qas_list[i], q)
            gold = append_metadata_at_end(orig_answers[i], context1, ctx_metas[0])

            prompts_accum.append({
                "id": batch_start + i,
                "question": q,
                "prompt": prompt_text,
                "gold_answer": gold,
                "context1": context1,
                "retrieved_qas": sim_qas_list[i],
                "relation_text": rel_list[i],
            })

        if len(prompts_accum) >= SAVE_BATCH_SIZE:
            # Everything up to batch_end is now either on disk or in this flush.
            _flush(batch_end)

    if prompts_accum:
        _flush(total)
    write_checkpoint(total)  # idempotent; covers an all-skipped tail

    OUT_DIR.joinpath("special_tokens_used.txt").write_text(
        "\n".join(STRUCTURAL_TOKENS), encoding='utf-8'
    )
    print("[DONE] All prompts processed.")
|
|
|
# Script entry point: run the full prompt-generation pipeline.
if __name__=="__main__":

    main()
|
|
|