from __future__ import annotations

from dataclasses import dataclass
from functools import lru_cache
from typing import Any, List

import torch
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

from src.data.utils import load_config

# ── Per-model instruction prefixes ───────────────────────────────────────────
# Each prompt asks for a formal, high-impact executive tone suitable for
# official intelligence reports.
_MODEL_PROMPTS: dict[str, str] = {
    "bart_large_cnn": (
        "Re-write the following traffic event into a highly professional executive "
        "incident brief. Focus on creating an impactful, formal summary highlighting "
        "severity and operational disruption: "
    ),
    "flan_t5_small": (
        "Task: Create a professional, high-impact Executive Traffic Intelligence Brief "
        "from the following incident. Emphasize severity, exact location, and direct "
        "consequences in a formal tone. "
        "Incident details: "
    ),
    "pegasus_cnn": (
        "Generate a formal, impactful Traffic Intelligence Report summarizing the key "
        "operational facts from this incident: "
    ),
}


@dataclass
class GenerationConfig:
    """Decoding parameters for a single summarization model."""

    max_input_tokens: int
    min_new_tokens: int
    max_new_tokens: int
    num_beams: int
    length_penalty: float
    no_repeat_ngram_size: int
    early_stopping: bool
    prompt_prefix: str = ""


def get_device() -> str:
    """Pick CUDA when available, otherwise fall back to CPU."""
    return "cuda" if torch.cuda.is_available() else "cpu"


@lru_cache(maxsize=8)
def load_tokenizer_and_model(hf_name: str) -> tuple[Any, Any]:
    """Load and cache a tokenizer/model pair, moved to the active device in eval mode."""
    tokenizer = AutoTokenizer.from_pretrained(hf_name)
    model = AutoModelForSeq2SeqLM.from_pretrained(hf_name)
    model.to(get_device())
    model.eval()
    return tokenizer, model


def build_generation_config(
    model_name: str, config_path: str = "config.yaml"
) -> tuple[str, GenerationConfig]:
    """Resolve the HF checkpoint name and decoding settings for a configured model."""
    cfg = load_config(config_path)
    model_cfg = cfg["models"][model_name]
    gen_cfg = cfg["generation"]
    return model_cfg["hf_name"], GenerationConfig(
        max_input_tokens=model_cfg.get("max_input_tokens", gen_cfg["default_max_input_tokens"]),
        min_new_tokens=gen_cfg["default_min_new_tokens"],
        max_new_tokens=gen_cfg["default_max_new_tokens"],
        num_beams=gen_cfg["num_beams"],
        length_penalty=1.0,  # neutral penalty; deliberately overrides any config value
        no_repeat_ngram_size=gen_cfg["no_repeat_ngram_size"],
        early_stopping=gen_cfg["early_stopping"],
        prompt_prefix=model_cfg.get("prompt_prefix", ""),
    )
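
# For reference, build_generation_config expects config.yaml to look roughly
# like this (keys taken from the lookups above; values are illustrative only):
#
#   models:
#     flan_t5_small:
#       hf_name: google/flan-t5-small
#       enabled: true
#       max_input_tokens: 512
#       prompt_prefix: ""
#   generation:
#     default_max_input_tokens: 1024
#     default_min_new_tokens: 20
#     default_max_new_tokens: 128
#     num_beams: 4
#     no_repeat_ngram_size: 3
#     early_stopping: true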


def generate_summary(
    text: str,
    model_name: str,
    config_path: str = "config.yaml",
    max_new_tokens: int | None = None,
) -> str:
    """Summarize a traffic incident using the named model's instruction prompt."""
    hf_name, gen = build_generation_config(model_name, config_path)
    tokenizer, model = load_tokenizer_and_model(hf_name)

    # Use the model-specific rewriting instruction if available, otherwise fall
    # back to the prefix from the config file.
    instruction = _MODEL_PROMPTS.get(model_name, gen.prompt_prefix)
    clean_text = " ".join(str(text).split())
    source_text = f"{instruction}{clean_text}"

    encoded = tokenizer(
        source_text,
        truncation=True,
        max_length=gen.max_input_tokens,
        return_tensors="pt",
    )
    encoded = {k: v.to(get_device()) for k, v in encoded.items()}

    # Cap generation at the configured budget (or the caller's override) and
    # let beam search decide where to stop within it.
    actual_max_tokens = max_new_tokens or gen.max_new_tokens

    with torch.inference_mode():
        output_ids = model.generate(
            **encoded,
            min_new_tokens=gen.min_new_tokens,
            max_new_tokens=actual_max_tokens,
            num_beams=gen.num_beams,
            length_penalty=gen.length_penalty,
            no_repeat_ngram_size=gen.no_repeat_ngram_size,
            early_stopping=True,  # always stop early, regardless of the config flag
        )
    output_text = " ".join(tokenizer.decode(output_ids[0], skip_special_tokens=True).split())

    # Strip an echoed instruction template from the start of the output. The
    # match is fuzzy (first 20 characters, case-insensitive), so strip the
    # variant that actually matched rather than the raw prefix.
    for prefix in _MODEL_PROMPTS.values():
        candidate = prefix.replace("Task: ", "").strip()
        if output_text.lower().startswith(candidate.lower()[:20]):
            output_text = output_text[len(candidate):].strip()
            break

    # Generic stripping of headings the models sometimes generate.
    output_text = output_text.replace("Executive Incident Brief:", "")
    output_text = output_text.replace("Traffic Intelligence Report:", "")
    output_text = output_text.replace("Incident report:", "")

    # Strip known hallucinated sentences (CNN/DailyMail boilerplate and
    # invented investigation statements).
    hallucinations = [
        "For confidential support call the Samaritans in the UK on 08457 90 90 90, visit a local Samaritans branch or click here for details.",
        "For confidential support call the Samaritans",
        "The cause of the collision has not been determined",
        "The incident is under investigation by Dubai Police.",
        "The incident is currently under investigation and no further details have been released.",
    ]
    for h in hallucinations:
        output_text = output_text.replace(h, "").strip()

    return " ".join(output_text.split())


def available_abstractive_models(config_path: str = "config.yaml") -> List[str]:
    """Return the names of models flagged as enabled in the config file."""
    cfg = load_config(config_path)
    return [name for name, meta in cfg["models"].items() if meta.get("enabled", False)]
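

# Minimal usage sketch: assumes config.yaml enables at least one model; the
# sample incident text below is purely illustrative.
if __name__ == "__main__":
    sample_incident = (
        "Multi-vehicle collision on the main highway near Exit 41; two lanes "
        "blocked and recovery crews on scene, with heavy tailbacks forming."
    )
    for name in available_abstractive_models():
        print(f"--- {name} ---")
        print(generate_summary(sample_incident, name))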