import os
import gradio as gr
import torch
from transformers import BartForConditionalGeneration, BartTokenizer
import re
import numpy as np
import networkx as nx
import plotly.graph_objects as go
from typing import List, Dict
from src.utils.get_model import get_summarizer, get_extractive_model, get_extractive_abstractive
from src.preprocessing.edu_sentences import preprocess_external_text
from src.model.baseline_extractive_model import get_trigrams
from typing import List, Dict, Tuple
# Hugging Face Hub repo IDs for each model variant. The "-edu" suffixed
# checkpoints were trained on EDU-segmented inputs; the others on
# sentence-segmented inputs. Selection happens per-request in the model_* functions.
REPO_ID_baseline_model = "Reality8081/bart-base"
REPO_ID_baseline_model_edu = "Reality8081/bart-base-edu"
REPO_ID_baseline_extractive_model = "Reality8081/bart_extractive"
REPO_ID_baseline_extractive_model_edu = "Reality8081/bart_extractive-edu"
REPO_ID_Extabs_model = "Reality8081/bart-encoder-decoder"
REPO_ID_Extabs_model_edu = "Reality8081/bart-encoder-decoder-edu"
# Run inference on GPU when available, otherwise fall back to CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
def create_saliency_plot(segment_scores: List[float], selected_indices: List[int], segments: List[str]) -> go.Figure:
    """Build a bar chart of per-segment saliency scores.

    Bars for segments picked into the summary are drawn red, the rest blue.

    Args:
        segment_scores: Saliency score (0-1) for each sentence/EDU segment.
        selected_indices: Indices of segments chosen for the summary.
        segments: Raw text of each segment (shown truncated in hover tooltips).

    Returns:
        A Plotly figure ready for the Gradio ``Plot`` component.
    """
    colors = ['#EF4444' if i in selected_indices else '#3B82F6' for i in range(len(segment_scores))]
    # Bug fix: the hover-text f-string was broken across raw source lines
    # (a syntax error). Plotly renders "<br>" as a line break inside hover
    # text, so join the three fields with it; segment text is truncated to
    # 60 characters to keep the tooltip readable.
    hover_texts = [
        f"Segment {i}<br>Score: {score:.3f}<br>Text: {seg[:60]}..."
        for i, (score, seg) in enumerate(zip(segment_scores, segments))
    ]
    fig = go.Figure(data=[go.Bar(
        x=[f"Seg {i}" for i in range(len(segment_scores))],
        y=segment_scores,
        marker_color=colors,
        text=[f"{s:.2f}" for s in segment_scores],
        textposition='auto',
        hovertext=hover_texts,
        hoverinfo="text"
    )])
    fig.update_layout(
        title="Saliency Scores trên từng Câu/EDU",
        xaxis_title="Vị trí Câu / EDU",
        yaxis_title="Saliency Score (0-1)",
        template="plotly_white",
        margin=dict(l=40, r=40, t=40, b=40),
        height=350
    )
    return fig
def model_baseline(prepro_dict: Dict) -> Tuple[str, None, None]:
    """Summarize with the plain abstractive baseline model.

    Args:
        prepro_dict: Preprocessing output; must contain "article" (raw text)
            and optionally "segmentation_method" ("edu" or sentence-based).

    Returns:
        (summary, None, None) — the baseline produces no confidence score or
        saliency plot, so the last two slots are None for caller symmetry.
    """
    text_to_summarize = prepro_dict.get("article", "")
    if not text_to_summarize.strip():
        # Bug fix: return a 3-tuple like the success path; callers unpack
        # (summary, confidence, figure) and a bare string crashed them.
        return "Lỗi: Không tìm thấy văn bản để tóm tắt.", None, None
    segmentation_method = prepro_dict.get("segmentation_method")
    # Pick the checkpoint matching how the input was segmented.
    if segmentation_method == "edu":
        repo_id = REPO_ID_baseline_model_edu
    else:
        repo_id = REPO_ID_baseline_model
    summarizer = get_summarizer(repo_id)
    summary = summarizer.summarize(text_to_summarize)
    return summary, None, None
def model_baseline_extractive(prepro_dict: Dict, top_n: int = 5) -> Tuple[str, float, go.Figure]:
    """Extractive summarization: score every segment, keep the top-n.

    Token-level logits from the extractive BART head are sigmoid-mapped to
    probabilities and averaged per segment. Segments are then greedily picked
    in descending score order with a trigram-overlap filter (to avoid
    redundancy) and finally re-sorted into document order.

    Args:
        prepro_dict: Preprocessing output with "segments", "input_ids",
            "attention_mask" and "segmentation_method".
        top_n: Maximum number of segments to keep in the summary.

    Returns:
        (extractive summary, mean score of picked segments, saliency chart).

    Raises:
        ValueError: If the text could not be split into segments.
    """
    tokenizer = BartTokenizer.from_pretrained("facebook/bart-large")
    segments = prepro_dict["segments"]
    if not segments:
        raise ValueError("Không thể phân tách văn bản thành các câu/EDU.")
    input_ids = torch.tensor([prepro_dict["input_ids"]]).to(device)
    attention_mask = torch.tensor([prepro_dict["attention_mask"]]).to(device)
    segmentation_method = prepro_dict.get("segmentation_method")
    # Pick the checkpoint matching how the input was segmented.
    if segmentation_method == "edu":
        repo_id = REPO_ID_baseline_extractive_model_edu
    else:
        repo_id = REPO_ID_baseline_extractive_model
    model = get_extractive_model(repo_id=repo_id, device=device)
    model = model.to(torch.float32)
    model.eval()
    with torch.no_grad(), torch.autocast(device_type="cuda" if torch.cuda.is_available() else "cpu"):
        outputs = model(input_ids=input_ids, attention_mask=attention_mask)
        # Sigmoid maps each token logit into (0, 1) as a saliency probability.
        logits = outputs['logits'].to(torch.float32)
        probs = torch.sigmoid(logits).squeeze(0).cpu().numpy()
    # Align token-level probabilities back to segments by re-tokenizing each
    # segment and averaging the probabilities of its token span.
    segment_scores = []
    current_idx = 1  # Skip the special token at the start of the sequence.
    for seg in segments:
        # Number of tokens this segment occupies in the encoded input.
        seg_len = len(tokenizer.encode(seg, add_special_tokens=False))
        end_idx = min(current_idx + seg_len, len(probs))
        # Bug fix: require a non-empty slice (end_idx > current_idx);
        # np.mean over an empty slice yields NaN for zero-token segments
        # or segments that fall past the truncation boundary.
        if current_idx < end_idx:
            seg_score = float(np.mean(probs[current_idx:end_idx]))
        else:
            seg_score = 0.0
        segment_scores.append(seg_score)
        current_idx += seg_len
    ranked_indices = np.argsort(segment_scores)[::-1]  # Highest score first.
    selected_indices = []
    selected_trigrams = set()
    for idx in ranked_indices:
        candidate_seg = segments[idx]
        candidate_trigrams = get_trigrams(candidate_seg)
        # Keep a segment only if it shares no trigram with already-picked ones.
        if not candidate_trigrams.intersection(selected_trigrams):
            selected_indices.append(idx)
            selected_trigrams.update(candidate_trigrams)
        # Stop once we have gathered enough segments.
        if len(selected_indices) == top_n:
            break
    # Restore original document order so the summary reads coherently.
    selected_indices = sorted(selected_indices)
    extractive_summary = " ".join([segments[i] for i in selected_indices if i < len(segments)])
    avg_confidence = float(np.mean([segment_scores[i] for i in selected_indices])) if selected_indices else 0.0
    fig = create_saliency_plot(segment_scores, selected_indices, segments)
    return extractive_summary, avg_confidence, fig
def model_extractive_abstract(prepro_dict: Dict) -> Tuple[str, float, go.Figure]:
    """Hybrid summarization: extractive scoring + abstractive generation.

    The shared BART encoder is run once; its hidden states feed the
    extractive head (`model.ext_head`) whose sigmoid outputs give
    per-segment saliency scores for the plot and confidence metric. The
    final summary text comes from `model.generate_summary` (beam search).

    Args:
        prepro_dict: Preprocessing output with "segments", "input_ids",
            "attention_mask" and "segmentation_method".

    Returns:
        (abstractive summary, mean score of segments scoring >= 0.5,
        saliency chart).
    """
    tokenizer = BartTokenizer.from_pretrained("facebook/bart-large")
    segments = prepro_dict["segments"]
    input_ids = torch.tensor([prepro_dict["input_ids"]]).to(device)
    attention_mask = torch.tensor([prepro_dict["attention_mask"]]).to(device)
    segmentation_method = prepro_dict.get("segmentation_method")
    # Pick the checkpoint matching how the input was segmented.
    if segmentation_method == "edu":
        repo_id = REPO_ID_Extabs_model_edu
    else:
        repo_id = REPO_ID_Extabs_model
    model = get_extractive_abstractive(repo_id=repo_id, base_model_name="facebook/bart-large", device=device)
    with torch.no_grad():
        # Run only the encoder, then the extractive head on its hidden states.
        encoder_outputs = model.bart.model.encoder(input_ids=input_ids, attention_mask=attention_mask)
        hidden_states = encoder_outputs.last_hidden_state
        ext_logits = model.ext_head(hidden_states).squeeze(-1)
        probs = torch.sigmoid(ext_logits).squeeze(0).cpu().numpy()
    # Align token-level probabilities back to segments (same scheme as the
    # extractive baseline): average the probabilities of each segment's span.
    segment_scores = []
    current_idx = 1  # Skip the special token at the start of the sequence.
    for seg in segments:
        seg_len = len(tokenizer.encode(seg, add_special_tokens=False))
        end_idx = min(current_idx + seg_len, len(probs))
        # Bug fix: require a non-empty slice; np.mean over an empty slice
        # yields NaN for zero-token or truncated-away segments.
        if current_idx < end_idx:
            seg_score = float(np.mean(probs[current_idx:end_idx]))
        else:
            seg_score = 0.0
        segment_scores.append(seg_score)
        current_idx += seg_len
    # Segments whose saliency crosses 0.5 count as "selected" — they are
    # highlighted in the plot and averaged into the confidence metric.
    selected_indices = [i for i, score in enumerate(segment_scores) if score >= 0.5]
    avg_confidence = float(np.mean([segment_scores[i] for i in selected_indices])) if selected_indices else 0.0
    summary_ids = model.generate_summary(
        input_ids=input_ids,
        attention_mask=attention_mask,
        max_length=150,
        min_length=40,
        num_beams=4,
        early_stopping=True
    )
    summary = tokenizer.decode(summary_ids[0], skip_special_tokens=True)
    fig = create_saliency_plot(segment_scores, selected_indices, segments)
    return summary, avg_confidence, fig
# ====================== MAIN FUNCTION ======================
def ATS(
    text: str,
    segmentation_method: str = None,
    model: str = None,
    reference_summary: str = None
) -> "Tuple[str, str, go.Figure]":
    """Main pipeline wired to the Gradio "Summarize Now" button.

    Preprocesses the input, dispatches to the chosen model, and formats the
    result for the three output components. Bug fix: the return annotation
    said ``str`` although every path returns a 3-tuple.

    Args:
        text: Raw input text to summarize.
        segmentation_method: UI label of the preprocessing method.
        model: UI label of the summarization model.
        reference_summary: Optional gold summary forwarded to preprocessing.

    Returns:
        (summary text, metrics markdown, saliency figure) — the figure slot
        is None on the short-input and error paths.
    """
    # Guard: refuse inputs that are empty or shorter than 5 words.
    if not text or len(text.split()) < 5:
        return "Văn bản quá ngắn hoặc rỗng, vui lòng nhập thêm nội dung.", "⚠️ Không có dữ liệu", None
    try:
        # Step 1: segment the text (sentence-level or EDU-level).
        if segmentation_method == "Sentence-based Preprocessing":
            prepro_dict = preprocess_external_text(text, reference_summary, segmentation_method='sentence')
        else:
            prepro_dict = preprocess_external_text(text, reference_summary, segmentation_method='edu')
        # Step 2: dispatch to the selected model.
        if model == "Baseline Model":
            result, avg_conf, plot_fig = model_baseline(prepro_dict)
        elif model == "Baseline Model with Extractive":
            result, avg_conf, plot_fig = model_baseline_extractive(prepro_dict)
        else:
            result, avg_conf, plot_fig = model_extractive_abstract(prepro_dict)
        # Step 3: compute compression statistics for the metrics panel.
        origin_words = len(text.split())
        sum_words = len(result.split())
        comp_ratio = (sum_words / origin_words * 100) if origin_words > 0 else 0
        conf_str = f"**{avg_conf * 100:.2f}%**" if avg_conf is not None else "*N/A (Không tính Saliency Score)*"
        metrics_md = f"""
### 📊 Thống kê kết quả
- **Tổng số từ văn bản gốc:** {origin_words} từ
- **Số từ bản tóm tắt:** {sum_words} từ
- **Tỉ lệ nén:** {comp_ratio:.1f}% *(Giữ lại {comp_ratio:.1f}% dung lượng gốc)*
- **Độ tin cậy trung bình (Confidence):** {conf_str}
"""
        # The baseline model returns no figure; show a placeholder instead.
        if plot_fig is None:
            plot_fig = go.Figure()
            plot_fig.update_layout(title="Mô hình Baseline không hỗ trợ Saliency Plot.", template="plotly_white")
        return result, metrics_md, plot_fig
    except Exception as e:
        # Top-level UI boundary: surface any failure as a message instead of
        # crashing the Gradio worker.
        return f"Đã xảy ra lỗi: {str(e)}", "⚠️ Lỗi hệ thống", None
# ====================== GRADIO INTERFACE ======================
# Build the web UI: input textbox + settings on top, summary text and the
# saliency plot below. All components are wired to ATS() via the button.
with gr.Blocks(
    title="Automated Text Summarization System",
) as demo:
    # Page header.
    gr.Markdown(
        """
# 🚀 Automated Text Summarization System
**Input text → Select method & model → Get results instantly**
"""
    )
    with gr.Row():
        # Left column: the text to summarize and the action button.
        with gr.Column(scale=2):
            input_text = gr.Textbox(
                label="📝 Text to Summarize",
                placeholder="Paste your long text here (up to several thousand words)...",
                lines=15,
                max_lines=30,
            )
            with gr.Row():
                btn_summary = gr.Button(
                    "🔍 Summarize Now",
                    variant="primary",
                    size="large"
                )
        # Right column: preprocessing-method and model selectors. The choice
        # strings must match the labels ATS() compares against.
        with gr.Column(scale=1):
            gr.Markdown("### ⚙️ Settings")
            method = gr.Radio(
                choices=[
                    "Sentence-based Preprocessing",
                    "Semantic-based Preprocessing (EDU)"
                ],
                value="Sentence-based Preprocessing",
                label="Preprocessing Method",
                info="Choose how to clean the text before summarization"
            )
            model = gr.Radio(
                choices=[
                    "Baseline Model",
                    "Baseline Model with Extractive",
                    "Extractive and Abstractive Model"
                ],
                value="Baseline Model",
                label="Summarization Model",
                info="Select the model you want to use"
            )
    # Statistics panel, filled with markdown by ATS() after each run.
    output_metrics = gr.Markdown("### 📊 Thống kê kết quả\n*Chờ xử lý...*")
    gr.Markdown("---")
    with gr.Row():
        with gr.Column(scale=1):
            output_text = gr.Textbox(
                label="📄 Summary Result",
                lines=10,
                placeholder="The result will appear here...",
            )
        with gr.Column(scale=1):
            output_plot = gr.Plot(label="📈 Saliency Score Plot")
    # Wire the button to the main pipeline; ATS's reference_summary parameter
    # is not supplied by the UI and stays at its default (None).
    btn_summary.click(
        fn=ATS,
        inputs=[input_text, method, model],
        outputs=[output_text, output_metrics, output_plot]
    )
    # Clickable sample inputs for quick testing.
    gr.Examples(
        examples=[
            ["Hanoi is the capital of Vietnam. The city has a history of over 1000 years. The population is about 8 million people. It is the political, economic, and cultural center of the country."],
            ["Artificial Intelligence (AI) is changing the world. Many large companies are investing heavily in this field. Gradio helps developers quickly build web interfaces for AI models."],
        ],
        inputs=input_text,
        label="📌 Test Examples"
    )
    # Footer with usage instructions.
    gr.Markdown(
        """
---
💡 **User Guide**:
1. Paste the text to be summarized into the left box.
2. Select the preprocessing method and model.
3. Click **Summarize Now**.
The result will appear instantly (currently using mock data for demo).
"""
    )
# Launch the app
# Launch the app only when run as a script (not when imported).
if __name__ == "__main__":
    # Bug fix: Blocks.launch() accepts no `theme` or `css` keyword arguments
    # (those belong to the gr.Blocks(...) constructor), so passing them here
    # raised TypeError and the app never started. NOTE(review): to restore the
    # Soft theme and the custom container CSS, pass theme=gr.themes.Soft()
    # and css=... to gr.Blocks(...) where the UI is built.
    demo.launch(
        share=True,
        ssr_mode=False
    )