File size: 5,365 Bytes
e078b1d
 
 
 
 
 
 
 
a543f4f
7976e9d
 
a543f4f
 
7976e9d
 
 
a543f4f
 
7976e9d
 
 
 
a543f4f
 
7976e9d
 
a543f4f
 
 
e078b1d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7976e9d
e078b1d
 
 
 
 
 
 
 
a543f4f
 
 
 
 
 
e078b1d
 
d703e0b
7976e9d
 
 
d703e0b
e078b1d
 
 
d703e0b
e078b1d
 
7976e9d
 
d703e0b
e078b1d
55729b3
d703e0b
7976e9d
 
 
 
 
 
 
 
 
 
a543f4f
55729b3
 
 
 
 
a543f4f
55729b3
 
 
d703e0b
55729b3
e078b1d
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
from __future__ import annotations
from dataclasses import dataclass
from functools import lru_cache
from typing import Any, List
import torch
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
from src.data.utils import load_config

# ── Per-model instruction prefixes ────────────────────────────────────────────
# Maps a model key (the same key used under cfg["models"]) to the instruction
# text that generate_summary() places directly in front of the cleaned input.
# Models without an entry here fall back to the config's "prompt_prefix".
#
# The wording asks for a formal, high-impact executive tone suitable for
# official intelligence reports.
#
# Each prompt ends with a trailing space because it is concatenated with the
# input text without any separator. generate_summary() also iterates over these
# values to detect an instruction echoed at the start of the generated output.
_MODEL_PROMPTS: dict[str, str] = {
    "bart_large_cnn": (
        "Re-write the following traffic event into a highly professional executive "
        "incident brief. Focus on creating an impactful, formal summary highlighting "
        "severity and operational disruption: "
    ),
    "flan_t5_small": (
        "Task: Create a professional, high-impact Executive Traffic Intelligence Brief "
        "from the following incident. Emphasize severity, exact location, and direct "
        "consequences in a formal tone. "
        "Incident details: "
    ),
    "pegasus_cnn": (
        "Generate a formal, impactful Traffic Intelligence Report summarizing the key "
        "operational facts from this incident: "
    ),
}

@dataclass
class GenerationConfig:
    """Tokenisation and decoding settings for one summarisation model.

    Instances are produced by ``build_generation_config`` from the loaded
    config and consumed by ``generate_summary``.
    """

    # Passed as the tokenizer's ``max_length`` (with truncation enabled).
    max_input_tokens: int
    # Forwarded to ``model.generate(min_new_tokens=...)``.
    min_new_tokens: int
    # Default for ``model.generate(max_new_tokens=...)``; a caller of
    # generate_summary may override it per call.
    max_new_tokens: int
    # Forwarded to ``model.generate(num_beams=...)``.
    num_beams: int
    # Forwarded to ``model.generate(length_penalty=...)``.
    length_penalty: float
    # Forwarded to ``model.generate(no_repeat_ngram_size=...)``.
    no_repeat_ngram_size: int
    # Populated from the config, but NOTE(review): generate_summary currently
    # passes a literal ``early_stopping=True`` and does not read this field.
    early_stopping: bool
    # Fallback instruction text used when the model has no _MODEL_PROMPTS entry.
    prompt_prefix: str = ""

def get_device() -> str:
    """Pick the torch device string to run inference on.

    Returns:
        ``"cuda"`` when ``torch.cuda.is_available()`` is true, otherwise ``"cpu"``.
    """
    if torch.cuda.is_available():
        return "cuda"
    return "cpu"

@lru_cache(maxsize=8)
def load_tokenizer_and_model(hf_name: str) -> tuple[Any, Any]:
    """Load the tokenizer and seq2seq model for a Hugging Face checkpoint.

    Results are memoised per ``hf_name`` (up to 8 distinct names), so repeated
    calls with the same name return the same tokenizer/model objects.

    Args:
        hf_name: Checkpoint identifier passed to ``from_pretrained``.

    Returns:
        A ``(tokenizer, model)`` tuple. The model has been moved to the device
        reported by ``get_device()`` and had ``eval()`` called on it.
    """
    tok = AutoTokenizer.from_pretrained(hf_name)
    seq2seq = AutoModelForSeq2SeqLM.from_pretrained(hf_name)

    # Move first, then call eval(), in the same order as before.
    target_device = get_device()
    seq2seq.to(target_device)
    seq2seq.eval()
    return (tok, seq2seq)

def build_generation_config(model_name: str, config_path: str = "config.yaml"):
    """Resolve the checkpoint name and decoding settings for ``model_name``.

    Args:
        model_name: Key looked up under ``cfg["models"]``.
        config_path: Path handed to ``load_config``.

    Returns:
        A ``(hf_name, GenerationConfig)`` tuple.
    """
    settings = load_config(config_path)
    per_model = settings["models"][model_name]
    shared = settings["generation"]

    hf_name = per_model["hf_name"]

    # The shared default is looked up unconditionally, so it must exist in the
    # config even when the model supplies its own "max_input_tokens".
    fallback_input_limit = shared["default_max_input_tokens"]
    input_limit = per_model.get("max_input_tokens", fallback_input_limit)

    generation = GenerationConfig(
        max_input_tokens=input_limit,
        min_new_tokens=shared["default_min_new_tokens"],
        max_new_tokens=shared["default_max_new_tokens"],
        num_beams=shared["num_beams"],
        # Pinned to 1.0 rather than read from config, for natural flow.
        length_penalty=1.0,
        no_repeat_ngram_size=shared["no_repeat_ngram_size"],
        early_stopping=shared["early_stopping"],
        prompt_prefix=per_model.get("prompt_prefix", ""),
    )
    return hf_name, generation

# Minimum number of leading characters that must match an instruction prompt
# before the start of the output is treated as an echo of that prompt.
_MIN_ECHO_MATCH = 20

# Headings the models sometimes emit; removed wherever they occur.
_ECHOED_LABELS = (
    "Executive Incident Brief:",
    "Traffic Intelligence Report:",
    "Incident report:",
)

# Boilerplate sentences the models are known to invent; removed verbatim.
# Order matters: the long Samaritans sentence must be tried before its stem.
_KNOWN_HALLUCINATIONS = (
    "For confidential support call the Samaritans in the UK on 08457 90 90 90, visit a local Samaritans branch or click here for details.",
    "For confidential support call the Samaritans",
    "The cause of the collision has not been determined",
    "The incident is under investigation by Dubai Police.",
    "The incident is currently under investigation and no further details have been released.",
)

def _strip_prompt_echo(text: str) -> str:
    """Remove an instruction prompt echoed at the start of ``text``.

    For every prompt in ``_MODEL_PROMPTS`` the output is compared
    case-insensitively, character by character, against the prompt both with
    and without its ``"Task: "`` marker. If at least ``_MIN_ECHO_MATCH``
    leading characters agree, exactly the agreeing characters are dropped.
    Slicing by the matched length (not the full prompt length) keeps the real
    summary intact when the model echoes only part of the instruction or omits
    ``"Task: "``.
    """
    for prompt in _MODEL_PROMPTS.values():
        with_marker = prompt.strip()
        without_marker = prompt.replace("Task: ", "").strip()
        # dict.fromkeys de-duplicates while keeping order (most prompts have no marker).
        for candidate in dict.fromkeys((with_marker, without_marker)):
            if not candidate:
                continue
            matched = 0
            for out_ch, prompt_ch in zip(text, candidate):
                if out_ch.lower() != prompt_ch.lower():
                    break
                matched += 1
            if matched >= min(_MIN_ECHO_MATCH, len(candidate)):
                text = text[matched:].strip()
                break
    return text

def _remove_boilerplate(text: str) -> str:
    """Drop echoed headings and known hallucinated sentences from ``text``."""
    for label in _ECHOED_LABELS:
        text = text.replace(label, "")
    for phrase in _KNOWN_HALLUCINATIONS:
        if not phrase.endswith("."):
            # Take the sentence's full stop with it so no stray "." is left behind.
            text = text.replace(phrase + ".", "")
        text = text.replace(phrase, "").strip()
    return text

def generate_summary(text: str, model_name: str, config_path: str = "config.yaml", max_new_tokens: int | None = None) -> str:
    """Generate an abstractive summary of ``text`` with the named model.

    The input is whitespace-normalised, prefixed with the model's instruction
    (from ``_MODEL_PROMPTS``, else the config's ``prompt_prefix``), truncated
    to the configured input length and decoded with beam search. The decoded
    text is then cleaned of echoed instructions, echoed headings and known
    hallucinated sentences.

    Args:
        text: Source text to summarise. ``None`` or whitespace-only input
            yields an empty string without loading or running the model.
        model_name: Key of the model under ``cfg["models"]``.
        config_path: Path handed to ``build_generation_config``.
        max_new_tokens: Optional per-call cap on generated tokens. ``None``
            (or ``0``) means "use the configured default".

    Returns:
        The cleaned summary with all whitespace runs collapsed to one space.

    Raises:
        Propagates whatever ``build_generation_config`` raises, e.g. for a
        ``model_name`` that is missing from the config.
    """
    hf_name, gen = build_generation_config(model_name, config_path)

    clean_text = "" if text is None else " ".join(str(text).split())
    if not clean_text:
        # Nothing to summarise: a prompt-only request would just make the model
        # invent content, so skip the (expensive) model load entirely.
        return ""

    tokenizer, model = load_tokenizer_and_model(hf_name)

    # Use the model-specific rewriting instruction if available, else fall back
    # to the config prefix.
    instruction = _MODEL_PROMPTS.get(model_name, gen.prompt_prefix)
    source_text = f"{instruction}{clean_text}"

    device = get_device()
    encoded = tokenizer(source_text, truncation=True, max_length=gen.max_input_tokens, return_tensors="pt")
    encoded = {name: tensor.to(device) for name, tensor in encoded.items()}

    # Cap generation at the configured/requested budget and otherwise let the
    # model use its own stopping logic; a stricter dynamic limit made the
    # models copy the input or behave oddly.
    token_budget = max_new_tokens or gen.max_new_tokens
    # Keep the lower bound feasible when a caller asks for a very short output.
    min_tokens = min(gen.min_new_tokens, token_budget)

    with torch.inference_mode():
        output_ids = model.generate(
            **encoded,
            min_new_tokens=min_tokens,
            max_new_tokens=token_budget,
            num_beams=gen.num_beams,
            length_penalty=gen.length_penalty,
            no_repeat_ngram_size=gen.no_repeat_ngram_size,
            # NOTE(review): hard-coded; gen.early_stopping from the config is
            # not consulted here. Confirm whether that is intentional.
            early_stopping=True,
        )
    decoded = tokenizer.decode(output_ids[0], skip_special_tokens=True)
    output_text = " ".join(decoded.split())

    output_text = _strip_prompt_echo(output_text)
    output_text = _remove_boilerplate(output_text)

    return " ".join(output_text.split())

def available_abstractive_models(config_path: str = "config.yaml") -> List[str]:
    """List the model keys that the config marks as enabled.

    Args:
        config_path: Path handed to ``load_config``.

    Returns:
        Keys of ``cfg["models"]`` whose entry has a truthy ``"enabled"`` value,
        in the config's iteration order. Entries without the key are skipped.
    """
    models = load_config(config_path)["models"]
    enabled: List[str] = []
    for model_key, settings in models.items():
        if settings.get("enabled", False):
            enabled.append(model_key)
    return enabled