Rajeev Ranjan Pandey committed
Commit a543f4f · Parent: d703e0b

feat: directive prompts per model, BrainCircuit icon, info tooltip light mode fix

frontend/src/components/DatasetToggle.jsx CHANGED
@@ -1,15 +1,15 @@
-import { Check, Database } from "lucide-react";
+import { Check, BrainCircuit } from "lucide-react";
 
 export default function DatasetToggle({ value, onChange }) {
   const options = [
     { value: "gcc", label: "GCC / UAE", subtitle: "250+ Narrative Samples", flag: "🇦🇪" },
     { value: "us", label: "US Accidents", subtitle: "5,000+ Extracted Records", flag: "🇺🇸" }
   ];
 
   return (
     <div className="rounded-2xl border border-slate-300 dark:border-white/[0.07] bg-white dark:bg-[#0d1326] p-5 shadow-sm dark:shadow-xl">
       <div className="flex items-center gap-2 mb-4 text-[10px] font-bold uppercase tracking-[0.2em] text-slate-400 dark:text-slate-500">
-        <Database size={12}/> Analysis Dataset
+        <BrainCircuit size={13} className="text-orange-500"/> Available Datasets
       </div>
       <div className="flex flex-col gap-3">
         {options.map((option) => {
 
frontend/src/components/SummarizerWidget.jsx CHANGED
@@ -223,11 +223,13 @@ export default function SummarizerWidget({
             <Icon size={18} />
           </span>
           <div className="group/tooltip relative">
-            <div className="flex h-6 w-6 items-center justify-center rounded-full text-slate-400 hover:text-slate-700 transition cursor-help dark:text-slate-500 dark:hover:text-white">
-              <span className="font-serif italic border border-slate-300 dark:border-slate-600 rounded-full w-4 h-4 flex items-center justify-center text-[10px]">i</span>
+            <div className="flex h-6 w-6 items-center justify-center rounded-full hover:bg-orange-50 dark:hover:bg-white/10 transition cursor-help">
+              <span className="font-serif italic border-2 border-slate-400 dark:border-slate-500 text-slate-500 dark:text-slate-400 hover:border-orange-500 hover:text-orange-600 dark:hover:text-white rounded-full w-4 h-4 flex items-center justify-center text-[10px] transition">
+                i
+              </span>
             </div>
-            <div className="absolute right-0 lg:right-auto lg:left-0 top-8 z-50 w-64 opacity-0 scale-95 origin-top-right lg:origin-top-left transition-all group-hover/tooltip:opacity-100 group-hover/tooltip:scale-100 pointer-events-none group-hover/tooltip:pointer-events-auto rounded-xl bg-slate-900 border border-slate-800 p-3 shadow-xl dark:bg-slate-800 dark:border-slate-700">
-              <p className="text-xs text-slate-300 leading-relaxed font-normal">{model.description}</p>
+            <div className="absolute right-0 lg:right-auto lg:left-0 top-8 z-50 w-64 opacity-0 scale-95 origin-top-right lg:origin-top-left transition-all group-hover/tooltip:opacity-100 group-hover/tooltip:scale-100 pointer-events-none group-hover/tooltip:pointer-events-auto rounded-xl bg-white dark:bg-slate-800 border border-slate-200 dark:border-slate-700 p-3 shadow-xl">
+              <p className="text-xs text-slate-700 dark:text-slate-300 leading-relaxed font-normal">{model.description}</p>
             </div>
           </div>
         </div>
 
src/models/abstractive.py CHANGED
@@ -6,6 +6,24 @@ import torch
 from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
 from src.data.utils import load_config
 
+# ── Per-model instruction prefixes ────────────────────────────────────────────
+# Prepended to raw incident text so models rewrite instead of echo.
+_MODEL_PROMPTS: dict[str, str] = {
+    "bart_large_cnn": (
+        "Generate a concise traffic incident summary. "
+        "Report only: location, incident type, severity, and road impact. "
+        "Be brief. Incident report: "
+    ),
+    "flan_t5_small": (
+        "Write a one-sentence traffic incident summary covering location, "
+        "incident type, severity level, and road impact in under 35 words. "
+        "Traffic report: "
+    ),
+    "pegasus_cnn": (
+        "Summarize the key facts from this traffic incident in one compact sentence: "
+    ),
+}
+
 @dataclass
 class GenerationConfig:
     max_input_tokens: int
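Note: the dict is keyed by the internal model name, and the next hunk has generate_summary fall back to the config-supplied prompt_prefix for any model without a directive entry. A minimal sketch of that dispatch; FakeGen and "led_base" are hypothetical stand-ins, not names from this repo:

    # Sketch: lookup-with-fallback dispatch. FakeGen stands in for the
    # GenerationConfig loaded from config.yaml; "led_base" is a made-up key.
    from dataclasses import dataclass

    _MODEL_PROMPTS = {
        "pegasus_cnn": "Summarize the key facts from this traffic incident in one compact sentence: ",
    }

    @dataclass
    class FakeGen:
        prompt_prefix: str = "summarize: "

    gen = FakeGen()
    for model_name in ("pegasus_cnn", "led_base"):
        instruction = _MODEL_PROMPTS.get(model_name, gen.prompt_prefix)
        print(f"{model_name} -> {instruction!r}")
    # pegasus_cnn gets its directive prompt; led_base falls back to "summarize: ".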
@@ -46,14 +64,18 @@ def build_generation_config(model_name: str, config_path: str = "config.yaml"):
 def generate_summary(text: str, model_name: str, config_path: str = "config.yaml", max_new_tokens: int | None = None) -> str:
     hf_name, gen = build_generation_config(model_name, config_path)
     tokenizer, model = load_tokenizer_and_model(hf_name)
-    source_text = f"{gen.prompt_prefix}{' '.join(str(text).split())}"
+
+    # Use model-specific rewriting instruction if available, else fall back to config prefix.
+    instruction = _MODEL_PROMPTS.get(model_name, gen.prompt_prefix)
+    clean_text = " ".join(str(text).split())
+    source_text = f"{instruction}{clean_text}"
+
     encoded = tokenizer(source_text, truncation=True, max_length=gen.max_input_tokens, return_tensors="pt")
     encoded = {k: v.to(get_device()) for k, v in encoded.items()}
 
-    # Dynamically cap max_new_tokens at 55 % of input length to enforce compression.
-    # This prevents the model from echoing short inputs verbatim.
-    input_len = encoded["input_ids"].shape[-1]
-    dynamic_max = max(gen.min_new_tokens, min(int(input_len * 0.55), gen.max_new_tokens))
+    # Dynamic cap: limit output to 50 % of raw input token count to force compression.
+    raw_len = tokenizer(clean_text, return_tensors="pt")["input_ids"].shape[-1]
+    dynamic_max = max(gen.min_new_tokens, min(int(raw_len * 0.50), gen.max_new_tokens))
     actual_max_tokens = max_new_tokens or dynamic_max
 
     with torch.inference_mode():
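Measuring raw_len on clean_text rather than on the already-prefixed encoded input means the new, longer instruction prefixes no longer inflate the output budget. A worked check of the cap, using illustrative limits rather than the actual values from config.yaml:

    # Worked example of the dynamic cap; 20/120 are illustrative limits,
    # not the values from config.yaml.
    min_new_tokens, max_new_tokens = 20, 120

    def dynamic_cap(raw_len: int) -> int:
        return max(min_new_tokens, min(int(raw_len * 0.50), max_new_tokens))

    print(dynamic_cap(30))   # 20  -- min_new_tokens floor for very short inputs
    print(dynamic_cap(90))   # 45  -- 50 % of the raw input token count
    print(dynamic_cap(400))  # 120 -- clamped to max_new_tokens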
@@ -62,19 +84,19 @@ def generate_summary(text: str, model_name: str, config_path: str = "config.yaml
         min_new_tokens=gen.min_new_tokens,
         max_new_tokens=actual_max_tokens,
         num_beams=gen.num_beams,
-        length_penalty=2.0,       # strongly prefers shorter outputs
-        no_repeat_ngram_size=4,   # blocks 4-gram repetition / copy
+        length_penalty=3.0,       # strongly prefers concise outputs
+        no_repeat_ngram_size=4,   # blocks 4-gram copying from input
         early_stopping=True,
     )
     output_text = " ".join(tokenizer.decode(output_ids[0], skip_special_tokens=True).split())
 
-    # Post-processing: strip known hallucinations
+    # Strip known hallucinations
     hallucinations = [
         "For confidential support call the Samaritans in the UK on 08457 90 90 90, visit a local Samaritans branch or click here for details.",
         "For confidential support call the Samaritans",
         "The cause of the collision has not been determined",
         "The incident is under investigation by Dubai Police.",
-        "The incident is currently under investigation and no further details have been released."
+        "The incident is currently under investigation and no further details have been released.",
     ]
     for h in hallucinations:
         output_text = output_text.replace(h, "").strip()
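Two editorial notes on this hunk. First, per the Transformers documentation, beam-search scores are divided by length ** length_penalty, so length_penalty > 1.0 generally promotes longer candidates rather than shorter ones; the brevity here is enforced mainly by the dynamic max_new_tokens cap above. Second, str.replace can leave a doubled space when a stripped phrase sits mid-sentence, so a whitespace re-collapse after the loop is a cheap safeguard. A sketch of that guard (not part of this commit):

    # Sketch (not in this commit): re-collapse whitespace after stripping,
    # since str.replace leaves a doubled space when a phrase sits mid-sentence.
    hallucinations = ["The incident is under investigation by Dubai Police."]
    output_text = "Two cars collided on E11. The incident is under investigation by Dubai Police. One lane is closed."
    for h in hallucinations:
        output_text = output_text.replace(h, "").strip()
    output_text = " ".join(output_text.split())  # collapse the doubled space
    print(output_text)  # -> "Two cars collided on E11. One lane is closed."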
 