"""Model and metric utilities for ReACC-style generation.""" from __future__ import annotations from dataset import SPECIAL_TOKENS, build_prompt from transformers import AutoModelForCausalLM, AutoTokenizer import torch from typing import Dict, Optional, Sequence from difflib import SequenceMatcher import math import os import sys CURRENT_DIR = os.path.dirname(os.path.abspath(__file__)) if CURRENT_DIR not in sys.path: sys.path.insert(0, CURRENT_DIR) def load_model_and_tokenizer(model_name_or_path: str): tokenizer = AutoTokenizer.from_pretrained(model_name_or_path) if tokenizer.pad_token is None: tokenizer.pad_token = tokenizer.eos_token tokenizer.add_special_tokens({"additional_special_tokens": SPECIAL_TOKENS}) model = AutoModelForCausalLM.from_pretrained(model_name_or_path) model.resize_token_embeddings(len(tokenizer)) return tokenizer, model @torch.no_grad() def generate_completion( model, tokenizer, retrieved: str, context: str, device: torch.device, max_length: int = 384, max_new_tokens: int = 64, do_sample: bool = False, temperature: float = 0.2, top_p: float = 0.95, stop_strings: Optional[Sequence[str]] = None, ) -> str: prompt = build_prompt(retrieved, context) inputs = tokenizer(prompt, return_tensors="pt", truncation=True, max_length=max_length) inputs = {k: v.to(device) for k, v in inputs.items()} generation_kwargs = dict( max_new_tokens=max_new_tokens, do_sample=do_sample, num_beams=1, pad_token_id=tokenizer.pad_token_id, eos_token_id=tokenizer.eos_token_id, ) if do_sample: generation_kwargs["temperature"] = temperature generation_kwargs["top_p"] = top_p output = model.generate(**inputs, **generation_kwargs) full = tokenizer.decode(output[0], skip_special_tokens=False) prompt_text = tokenizer.decode( inputs["input_ids"][0], skip_special_tokens=False) generated = full[len(prompt_text):] if stop_strings: cut = None for s in stop_strings: pos = generated.find(s) if pos >= 0: cut = pos if cut is None else min(cut, pos) if cut is not None: generated = generated[:cut] return generated def exact_match(pred: str, gold: str) -> float: return 1.0 if pred.strip() == gold.strip() else 0.0 def edit_similarity(pred: str, gold: str) -> float: return SequenceMatcher(None, pred.strip(), gold.strip()).ratio() * 100.0 def perplexity_from_loss(loss_value: float) -> float: if loss_value >= 20: return float("inf") return math.exp(loss_value) def evaluate_generation(preds: Sequence[str], golds: Sequence[str]) -> Dict[str, float]: assert len(preds) == len(golds) if not preds: return {"exact_match": 0.0, "edit_similarity": 0.0} em = sum(exact_match(p, g) for p, g in zip(preds, golds)) / len(preds) es = sum(edit_similarity(p, g) for p, g in zip(preds, golds)) / len(preds) return {"exact_match": em * 100.0, "edit_similarity": es}