| """Model and metric utilities for ReACC-style generation."""
|
|
|
| from __future__ import annotations
|
| from dataset import SPECIAL_TOKENS, build_prompt
|
| from transformers import AutoModelForCausalLM, AutoTokenizer
|
| import torch
|
| from typing import Dict, Optional, Sequence
|
| from difflib import SequenceMatcher
|
| import math
|
|
|
| import os
|
| import sys
|
# Make sibling modules (e.g. ``dataset``) importable when this file is run as a
# script rather than installed as a package.
# NOTE(review): this runs *after* the ``from dataset import ...`` at the top of
# the file, so it cannot help that import — confirm the intended import order.
CURRENT_DIR = os.path.dirname(os.path.abspath(__file__))

if CURRENT_DIR not in sys.path:

    sys.path.insert(0, CURRENT_DIR)
|
|
|
|
|
def load_model_and_tokenizer(model_name_or_path: str):
    """Load a causal LM and its tokenizer, registering the project's special tokens.

    The tokenizer gains a pad token (reusing EOS when none is defined) plus the
    ``SPECIAL_TOKENS`` used by the prompt format; the model's embedding matrix
    is resized to stay in sync with the enlarged vocabulary.

    Returns:
        A ``(tokenizer, model)`` pair.
    """
    tok = AutoTokenizer.from_pretrained(model_name_or_path)
    if tok.pad_token is None:
        # Many causal-LM tokenizers ship without a pad token; EOS is the
        # conventional stand-in for batched generation.
        tok.pad_token = tok.eos_token
    tok.add_special_tokens({"additional_special_tokens": SPECIAL_TOKENS})

    lm = AutoModelForCausalLM.from_pretrained(model_name_or_path)
    # The vocabulary just grew by the added specials — resize embeddings to match.
    lm.resize_token_embeddings(len(tok))
    return tok, lm
|
|
|
|
|
@torch.no_grad()
def generate_completion(
    model,
    tokenizer,
    retrieved: str,
    context: str,
    device: torch.device,
    max_length: int = 384,
    max_new_tokens: int = 64,
    do_sample: bool = False,
    temperature: float = 0.2,
    top_p: float = 0.95,
    stop_strings: Optional[Sequence[str]] = None,
) -> str:
    """Generate a completion for *context* conditioned on *retrieved* code.

    Builds the ReACC prompt via ``build_prompt``, truncates it to
    ``max_length`` tokens, and greedily (or, when ``do_sample`` is set,
    nucleus-) samples up to ``max_new_tokens`` continuation tokens.

    Args:
        model: A causal LM supporting ``generate``.
        tokenizer: The tokenizer paired with *model*.
        retrieved: Retrieved reference snippet for the prompt.
        context: The partial code to complete.
        device: Device the input tensors are moved to.
        max_length: Token budget for the (truncated) prompt.
        max_new_tokens: Upper bound on generated tokens.
        do_sample: Enable sampling; ``temperature``/``top_p`` apply only then.
        temperature: Sampling temperature (ignored when greedy).
        top_p: Nucleus-sampling threshold (ignored when greedy).
        stop_strings: Optional strings; generation text is cut at the earliest
            occurrence of any of them.

    Returns:
        The decoded continuation text (special tokens retained), possibly
        truncated at the first stop string.
    """
    prompt = build_prompt(retrieved, context)
    inputs = tokenizer(prompt, return_tensors="pt",
                       truncation=True, max_length=max_length)
    inputs = {k: v.to(device) for k, v in inputs.items()}

    generation_kwargs = dict(
        max_new_tokens=max_new_tokens,
        do_sample=do_sample,
        num_beams=1,
        pad_token_id=tokenizer.pad_token_id,
        eos_token_id=tokenizer.eos_token_id,
    )
    if do_sample:
        generation_kwargs["temperature"] = temperature
        generation_kwargs["top_p"] = top_p

    output = model.generate(**inputs, **generation_kwargs)
    # Slice at the token level instead of decoding the full sequence and
    # cutting by len(decoded_prompt): detokenization is not round-trip exact
    # for many tokenizers, so a string-length cut can land mid-token.
    prompt_len = inputs["input_ids"].shape[1]
    generated = tokenizer.decode(output[0][prompt_len:],
                                 skip_special_tokens=False)

    if stop_strings:
        # Truncate at the earliest stop marker, if any is present.
        positions = [pos for s in stop_strings
                     if (pos := generated.find(s)) >= 0]
        if positions:
            generated = generated[:min(positions)]
    return generated
|
|
|
|
|
def exact_match(pred: str, gold: str) -> float:
    """Return 1.0 when *pred* equals *gold* after stripping whitespace, else 0.0."""
    return float(pred.strip() == gold.strip())
|
|
|
|
|
def edit_similarity(pred: str, gold: str) -> float:
    """Return the difflib similarity ratio of the stripped strings, scaled to 0-100."""
    lhs, rhs = pred.strip(), gold.strip()
    matcher = SequenceMatcher(None, lhs, rhs)
    return 100.0 * matcher.ratio()
|
|
|
|
|
def perplexity_from_loss(loss_value: float) -> float:
    """Convert a mean cross-entropy loss to perplexity (``exp(loss)``).

    Losses of 20 or more map to infinity: at that point the exponential is
    astronomically large and reporting ``inf`` avoids meaningless magnitudes.
    """
    return math.exp(loss_value) if loss_value < 20 else float("inf")
|
|
|
|
|
def evaluate_generation(preds: Sequence[str], golds: Sequence[str]) -> Dict[str, float]:
    """Compute corpus-level generation metrics.

    Args:
        preds: Predicted completion strings.
        golds: Reference completion strings, aligned with *preds*.

    Returns:
        Dict with ``exact_match`` (percentage, 0-100) and ``edit_similarity``
        (mean difflib ratio, already on a 0-100 scale). Both are 0.0 when the
        inputs are empty.

    Raises:
        ValueError: If *preds* and *golds* differ in length.
    """
    # Raise instead of assert: asserts are stripped under ``python -O`` and
    # a silent length mismatch would corrupt the averages.
    if len(preds) != len(golds):
        raise ValueError(
            f"preds and golds must have the same length "
            f"({len(preds)} vs {len(golds)})"
        )
    if not preds:
        return {"exact_match": 0.0, "edit_similarity": 0.0}
    em = sum(exact_match(p, g) for p, g in zip(preds, golds)) / len(preds)
    es = sum(edit_similarity(p, g) for p, g in zip(preds, golds)) / len(preds)
    # exact_match returns 0/1 per pair, so scale its mean to a percentage;
    # edit_similarity is already 0-100.
    return {"exact_match": em * 100.0, "edit_similarity": es}
|
|
|