# UIT.CS2229.ReACC / model_utils.py
# (uploaded by TranTruongMMCII, commit 7a911f3)
"""Model and metric utilities for ReACC-style generation."""
from __future__ import annotations
from dataset import SPECIAL_TOKENS, build_prompt
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
from typing import Dict, Optional, Sequence
from difflib import SequenceMatcher
import math
import os
import sys
CURRENT_DIR = os.path.dirname(os.path.abspath(__file__))
if CURRENT_DIR not in sys.path:
sys.path.insert(0, CURRENT_DIR)
def load_model_and_tokenizer(model_name_or_path: str):
    """Load a causal language model and its tokenizer for ReACC generation.

    Registers the project's SPECIAL_TOKENS with the tokenizer and resizes the
    model's embedding table accordingly.

    Args:
        model_name_or_path: HF hub id or local path of the pretrained model.

    Returns:
        A ``(tokenizer, model)`` tuple.
    """
    tok = AutoTokenizer.from_pretrained(model_name_or_path)
    # GPT-style tokenizers often ship without a pad token; fall back to EOS.
    if tok.pad_token is None:
        tok.pad_token = tok.eos_token
    tok.add_special_tokens({"additional_special_tokens": SPECIAL_TOKENS})

    lm = AutoModelForCausalLM.from_pretrained(model_name_or_path)
    # The embedding matrix must grow to cover the newly registered tokens.
    lm.resize_token_embeddings(len(tok))
    return tok, lm
@torch.no_grad()
def generate_completion(
    model,
    tokenizer,
    retrieved: str,
    context: str,
    device: torch.device,
    max_length: int = 384,
    max_new_tokens: int = 64,
    do_sample: bool = False,
    temperature: float = 0.2,
    top_p: float = 0.95,
    stop_strings: Optional[Sequence[str]] = None,
) -> str:
    """Generate a completion for ``context``, conditioned on ``retrieved`` code.

    Args:
        model: Causal LM with a ``generate`` method (HF interface).
        tokenizer: Tokenizer matching ``model``.
        retrieved: Retrieved reference snippet inserted into the prompt.
        context: Partial code to be completed.
        device: Device the model lives on; inputs are moved there.
        max_length: Truncation limit (in tokens) for the encoded prompt.
        max_new_tokens: Upper bound on generated tokens.
        do_sample: If True, sample with ``temperature``/``top_p``; else greedy.
        temperature: Sampling temperature (only used when ``do_sample``).
        top_p: Nucleus-sampling threshold (only used when ``do_sample``).
        stop_strings: Optional strings; generation text is cut at the earliest
            occurrence of any of them.

    Returns:
        The generated continuation (special tokens preserved), possibly
        truncated at the first stop string.
    """
    prompt = build_prompt(retrieved, context)
    inputs = tokenizer(prompt, return_tensors="pt",
                       truncation=True, max_length=max_length)
    inputs = {k: v.to(device) for k, v in inputs.items()}
    generation_kwargs = dict(
        max_new_tokens=max_new_tokens,
        do_sample=do_sample,
        num_beams=1,
        pad_token_id=tokenizer.pad_token_id,
        eos_token_id=tokenizer.eos_token_id,
    )
    if do_sample:
        generation_kwargs["temperature"] = temperature
        generation_kwargs["top_p"] = top_p
    output = model.generate(**inputs, **generation_kwargs)
    # Slice at the TOKEN level rather than decoding the prompt and cutting the
    # decoded string by length: tokenizer decoding does not always round-trip
    # exactly (whitespace merging, byte-level artifacts), which could corrupt
    # the boundary between prompt and completion.
    prompt_token_count = inputs["input_ids"].shape[1]
    generated = tokenizer.decode(
        output[0][prompt_token_count:], skip_special_tokens=False)
    if stop_strings:
        # Truncate at the earliest occurrence of any stop string.
        cut = None
        for s in stop_strings:
            pos = generated.find(s)
            if pos >= 0:
                cut = pos if cut is None else min(cut, pos)
        if cut is not None:
            generated = generated[:cut]
    return generated
def exact_match(pred: str, gold: str) -> float:
    """Return 1.0 when *pred* equals *gold* after stripping outer whitespace, else 0.0."""
    return float(pred.strip() == gold.strip())
def edit_similarity(pred: str, gold: str) -> float:
    """Return the difflib similarity ratio between stripped strings, scaled to 0-100."""
    a = pred.strip()
    b = gold.strip()
    matcher = SequenceMatcher(None, a, b)
    return 100.0 * matcher.ratio()
def perplexity_from_loss(loss_value: float) -> float:
    """Convert a mean cross-entropy loss to perplexity.

    Losses of 20 or more are mapped to ``inf`` to sidestep ``math.exp``
    overflow on degenerate values.
    """
    return math.exp(loss_value) if loss_value < 20 else float("inf")
def evaluate_generation(preds: Sequence[str], golds: Sequence[str]) -> Dict[str, float]:
    """Aggregate exact-match and edit-similarity over paired predictions.

    Args:
        preds: Generated completions.
        golds: Reference completions, aligned 1:1 with ``preds``.

    Returns:
        Dict with ``exact_match`` and ``edit_similarity``, both on a 0-100
        scale; zeros when the input is empty.

    Raises:
        ValueError: If ``preds`` and ``golds`` have different lengths.
    """
    # Raise explicitly instead of `assert`: asserts are stripped under -O,
    # which would silently let mismatched inputs through.
    if len(preds) != len(golds):
        raise ValueError(
            f"preds and golds differ in length: {len(preds)} != {len(golds)}")
    if not preds:
        return {"exact_match": 0.0, "edit_similarity": 0.0}
    n = len(preds)
    em = sum(exact_match(p, g) for p, g in zip(preds, golds)) / n
    es = sum(edit_similarity(p, g) for p, g in zip(preds, golds)) / n
    # exact_match is averaged as 0/1 then scaled; edit_similarity is already 0-100.
    return {"exact_match": em * 100.0, "edit_similarity": es}