Rajeev Ranjan Pandey committed
Commit · a543f4f
1 Parent(s): d703e0b
feat: directive prompts per model, BrainCircuit icon, info tooltip light mode fix
frontend/src/components/DatasetToggle.jsx
CHANGED
@@ -1,15 +1,15 @@
-import { Check, …
+import { Check, BrainCircuit } from "lucide-react";
 
 export default function DatasetToggle({ value, onChange }) {
   const options = [
-    { value: "gcc", label: "GCC / UAE", …
-    { value: "us", label: "US Accidents", subtitle: "5,000+ Extracted Records", …
+    { value: "gcc", label: "GCC / UAE", subtitle: "250+ Narrative Samples", flag: "🇦🇪" },
+    { value: "us", label: "US Accidents", subtitle: "5,000+ Extracted Records", flag: "🇺🇸" }
   ];
 
   return (
     <div className="rounded-2xl border border-slate-300 dark:border-white/[0.07] bg-white dark:bg-[#0d1326] p-5 shadow-sm dark:shadow-xl">
       <div className="flex items-center gap-2 mb-4 text-[10px] font-bold uppercase tracking-[0.2em] text-slate-400 dark:text-slate-500">
-        <…
+        <BrainCircuit size={13} className="text-orange-500"/> Available Datasets
       </div>
       <div className="flex flex-col gap-3">
         {options.map((option) => {
frontend/src/components/SummarizerWidget.jsx
CHANGED
@@ -223,11 +223,13 @@ export default function SummarizerWidget({
                 <Icon size={18} />
               </span>
               <div className="group/tooltip relative">
-                <div className="flex h-6 w-6 items-center justify-center rounded-full …
-                <span className="font-serif italic border border-slate-…
+                <div className="flex h-6 w-6 items-center justify-center rounded-full hover:bg-orange-50 dark:hover:bg-white/10 transition cursor-help">
+                  <span className="font-serif italic border-2 border-slate-400 dark:border-slate-500 text-slate-500 dark:text-slate-400 hover:border-orange-500 hover:text-orange-600 dark:hover:text-white rounded-full w-4 h-4 flex items-center justify-center text-[10px] transition">
+                    i
+                  </span>
                 </div>
-                <div className="absolute right-0 lg:right-auto lg:left-0 top-8 z-50 w-64 opacity-0 scale-95 origin-top-right lg:origin-top-left transition-all group-hover/tooltip:opacity-100 group-hover/tooltip:scale-100 pointer-events-none group-hover/tooltip:pointer-events-auto rounded-xl bg-…
-                <p className="text-xs text-slate-300 leading-relaxed font-normal">{model.description}</p>
+                <div className="absolute right-0 lg:right-auto lg:left-0 top-8 z-50 w-64 opacity-0 scale-95 origin-top-right lg:origin-top-left transition-all group-hover/tooltip:opacity-100 group-hover/tooltip:scale-100 pointer-events-none group-hover/tooltip:pointer-events-auto rounded-xl bg-white dark:bg-slate-800 border border-slate-200 dark:border-slate-700 p-3 shadow-xl">
+                  <p className="text-xs text-slate-700 dark:text-slate-300 leading-relaxed font-normal">{model.description}</p>
                 </div>
               </div>
             </div>
src/models/abstractive.py
CHANGED
@@ -6,6 +6,24 @@ import torch
 from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
 from src.data.utils import load_config
 
+# ── Per-model instruction prefixes ───────────────────────────────────────────
+# Prepended to raw incident text so models rewrite instead of echo.
+_MODEL_PROMPTS: dict[str, str] = {
+    "bart_large_cnn": (
+        "Generate a concise traffic incident summary. "
+        "Report only: location, incident type, severity, and road impact. "
+        "Be brief. Incident report: "
+    ),
+    "flan_t5_small": (
+        "Write a one-sentence traffic incident summary covering location, "
+        "incident type, severity level, and road impact in under 35 words. "
+        "Traffic report: "
+    ),
+    "pegasus_cnn": (
+        "Summarize the key facts from this traffic incident in one compact sentence: "
+    ),
+}
+
 @dataclass
 class GenerationConfig:
     max_input_tokens: int
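Not part of the commit: a minimal sketch of how the prefix lookup above is meant to behave. The dictionary is keyed by the app's internal model names; "some_future_model" and the "Summarize: " fallback are hypothetical stand-ins for the gen.prompt_prefix value loaded from config.yaml.

# Sketch of the _MODEL_PROMPTS fallback; "some_future_model" is invented.
_MODEL_PROMPTS = {
    "pegasus_cnn": "Summarize the key facts from this traffic incident in one compact sentence: ",
}

def pick_instruction(model_name: str, config_prefix: str) -> str:
    # Known models get their tuned directive; unknown ones keep the config prefix.
    return _MODEL_PROMPTS.get(model_name, config_prefix)

print(pick_instruction("pegasus_cnn", "Summarize: "))        # tuned directive
print(pick_instruction("some_future_model", "Summarize: "))  # falls back to config prefix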
@@ -46,14 +64,18 @@ def build_generation_config(model_name: str, config_path: str = "config.yaml"):
 def generate_summary(text: str, model_name: str, config_path: str = "config.yaml", max_new_tokens: int | None = None) -> str:
     hf_name, gen = build_generation_config(model_name, config_path)
     tokenizer, model = load_tokenizer_and_model(hf_name)
-    …
+
+    # Use model-specific rewriting instruction if available, else fall back to config prefix.
+    instruction = _MODEL_PROMPTS.get(model_name, gen.prompt_prefix)
+    clean_text = " ".join(str(text).split())
+    source_text = f"{instruction}{clean_text}"
+
     encoded = tokenizer(source_text, truncation=True, max_length=gen.max_input_tokens, return_tensors="pt")
     encoded = {k: v.to(get_device()) for k, v in encoded.items()}
 
-    # …
-    …
-    …
-    dynamic_max = max(gen.min_new_tokens, min(int(input_len * 0.55), gen.max_new_tokens))
+    # Dynamic cap: limit output to 50 % of raw input token count to force compression.
+    raw_len = tokenizer(clean_text, return_tensors="pt")["input_ids"].shape[-1]
+    dynamic_max = max(gen.min_new_tokens, min(int(raw_len * 0.50), gen.max_new_tokens))
     actual_max_tokens = max_new_tokens or dynamic_max
 
     with torch.inference_mode():
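A worked sketch of the new dynamic cap, assuming min_new_tokens=20 and max_new_tokens=96 (illustrative values only; the real ones come from config.yaml). Note that raw_len is measured on clean_text rather than source_text, so the instruction prefix does not inflate the output budget.

# Illustrative config values; config.yaml supplies the real min/max.
def dynamic_cap(raw_len: int, min_new: int = 20, max_new: int = 96) -> int:
    # Target 50 % of the input token count, clamped to [min_new, max_new].
    return max(min_new, min(int(raw_len * 0.50), max_new))

assert dynamic_cap(150) == 75  # mid-length input: 50 % compression target
assert dynamic_cap(30) == 20   # short input floors at min_new_tokens
assert dynamic_cap(400) == 96  # long input ceilings at max_new_tokens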
@@ -62,19 +84,19 @@ def generate_summary(text: str, model_name: str, config_path: str = "config.yaml", max_new_tokens: int | None = None) -> str:
         min_new_tokens=gen.min_new_tokens,
         max_new_tokens=actual_max_tokens,
         num_beams=gen.num_beams,
-        length_penalty=…
-        no_repeat_ngram_size=4,
+        length_penalty=3.0,  # strongly prefers concise outputs
+        no_repeat_ngram_size=4,  # blocks 4-gram copying from input
         early_stopping=True,
     )
     output_text = " ".join(tokenizer.decode(output_ids[0], skip_special_tokens=True).split())
 
-    # …
+    # Strip known hallucinations
     hallucinations = [
         "For confidential support call the Samaritans in the UK on 08457 90 90 90, visit a local Samaritans branch or click here for details.",
         "For confidential support call the Samaritans",
         "The cause of the collision has not been determined",
         "The incident is under investigation by Dubai Police.",
-        "The incident is currently under investigation and no further details have been released."
+        "The incident is currently under investigation and no further details have been released.",
     ]
     for h in hallucinations:
         output_text = output_text.replace(h, "").strip()
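One caveat on the inline comment above: in Hugging Face transformers, beam scores are divided by length ** length_penalty, so values above 1.0 actually favor longer hypotheses, not shorter ones; the effective brevity controls here are the directive prompts and the dynamic max_new_tokens cap. The hallucination filter itself is a plain substring-removal pass; a sketch with an invented model output:

# Invented example output; the filter deletes known boilerplate verbatim.
hallucinations = ["The incident is under investigation by Dubai Police."]
output_text = ("Two-vehicle collision on Sheikh Zayed Road caused minor delays. "
               "The incident is under investigation by Dubai Police.")
for h in hallucinations:
    output_text = output_text.replace(h, "").strip()
print(output_text)  # -> Two-vehicle collision on Sheikh Zayed Road caused minor delays.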