CUD-Traffic-AI / src /models /abstractive.py
Rajeev Ranjan Pandey
fix: restore working copy/save buttons in dark mode and elevate model prompt styling
7976e9d
from __future__ import annotations
from dataclasses import dataclass
from functools import lru_cache
from typing import Any, List
import torch
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
from src.data.utils import load_config
# ── Per-model instruction prefixes ────────────────────────────────────────────
# Changed prompting to be highly professional, requesting a "classy",
# high-impact executive tone suitable for official intelligence reports.
_MODEL_PROMPTS: dict[str, str] = {
"bart_large_cnn": (
"Re-write the following traffic event into a highly professional executive "
"incident brief. Focus on creating an impactful, formal summary highlighting "
"severity and operational disruption: "
),
"flan_t5_small": (
"Task: Create a professional, high-impact Executive Traffic Intelligence Brief "
"from the following incident. Emphasize severity, exact location, and direct "
"consequences in a formal tone. "
"Incident details: "
),
"pegasus_cnn": (
"Generate a formal, impactful Traffic Intelligence Report summarizing the key "
"operational facts from this incident: "
),
}
@dataclass
class GenerationConfig:
max_input_tokens: int
min_new_tokens: int
max_new_tokens: int
num_beams: int
length_penalty: float
no_repeat_ngram_size: int
early_stopping: bool
prompt_prefix: str = ""
def get_device() -> str:
return "cuda" if torch.cuda.is_available() else "cpu"
@lru_cache(maxsize=8)
def load_tokenizer_and_model(hf_name: str) -> tuple[Any, Any]:
tokenizer = AutoTokenizer.from_pretrained(hf_name)
model = AutoModelForSeq2SeqLM.from_pretrained(hf_name)
model.to(get_device())
model.eval()
return tokenizer, model
def build_generation_config(model_name: str, config_path: str = "config.yaml"):
cfg = load_config(config_path)
model_cfg = cfg["models"][model_name]
gen_cfg = cfg["generation"]
return model_cfg["hf_name"], GenerationConfig(
max_input_tokens=model_cfg.get("max_input_tokens", gen_cfg["default_max_input_tokens"]),
min_new_tokens=gen_cfg["default_min_new_tokens"],
max_new_tokens=gen_cfg["default_max_new_tokens"],
num_beams=gen_cfg["num_beams"],
length_penalty=1.0, # Reverted length_penalty to 1.0 (defaults) for natural flow
no_repeat_ngram_size=gen_cfg["no_repeat_ngram_size"],
early_stopping=gen_cfg["early_stopping"],
prompt_prefix=model_cfg.get("prompt_prefix", ""),
)
def generate_summary(text: str, model_name: str, config_path: str = "config.yaml", max_new_tokens: int | None = None) -> str:
hf_name, gen = build_generation_config(model_name, config_path)
tokenizer, model = load_tokenizer_and_model(hf_name)
# Use model-specific rewriting instruction if available, else fall back to config prefix.
instruction = _MODEL_PROMPTS.get(model_name, gen.prompt_prefix)
clean_text = " ".join(str(text).split())
source_text = f"{instruction}{clean_text}"
encoded = tokenizer(source_text, truncation=True, max_length=gen.max_input_tokens, return_tensors="pt")
encoded = {k: v.to(get_device()) for k, v in encoded.items()}
# Limit to max_tokens configured. The previous dynamic strict limit forced the models
# to behave weirdly or copy, instead let the model use its own stopping logic.
actual_max_tokens = max_new_tokens or gen.max_new_tokens
with torch.inference_mode():
output_ids = model.generate(
**encoded,
min_new_tokens=gen.min_new_tokens,
max_new_tokens=actual_max_tokens,
num_beams=gen.num_beams,
length_penalty=gen.length_penalty,
no_repeat_ngram_size=gen.no_repeat_ngram_size,
early_stopping=True,
)
output_text = " ".join(tokenizer.decode(output_ids[0], skip_special_tokens=True).split())
# Strip the instruction template echo
for prefix in _MODEL_PROMPTS.values():
if output_text.lower().startswith(prefix.replace("Task: ", "").lower().strip()[:20]):
output_text = output_text[len(prefix):].strip()
# Generic stripping of prefixes the models sometimes generate
output_text = output_text.replace("Executive Incident Brief:", "")
output_text = output_text.replace("Traffic Intelligence Report:", "")
output_text = output_text.replace("Incident report:", "")
# Strip known hallucinations
hallucinations = [
"For confidential support call the Samaritans in the UK on 08457 90 90 90, visit a local Samaritans branch or click here for details.",
"For confidential support call the Samaritans",
"The cause of the collision has not been determined",
"The incident is under investigation by Dubai Police.",
"The incident is currently under investigation and no further details have been released.",
]
for h in hallucinations:
output_text = output_text.replace(h, "").strip()
return " ".join(output_text.split())
def available_abstractive_models(config_path: str = "config.yaml") -> List[str]:
cfg = load_config(config_path)
return [name for name, meta in cfg["models"].items() if meta.get("enabled", False)]