Spaces:
Sleeping
Sleeping
Commit ·
fdcfd15
1
Parent(s): 846a1b0
Update SRC basline and baseline_extractive
Browse files- app.py +155 -117
- src/model/baseline_extractive_model.py +45 -0
- src/model/baseline_model.py +87 -0
- src/preprocessing/edu_sentences.py +178 -0
- src/utils/get_model.py +35 -0
app.py
CHANGED
|
@@ -1,94 +1,132 @@
|
|
| 1 |
import gradio as gr
|
| 2 |
-
|
| 3 |
-
|
| 4 |
-
|
| 5 |
-
|
| 6 |
-
|
| 7 |
-
|
| 8 |
-
|
| 9 |
-
|
| 10 |
-
|
| 11 |
-
|
| 12 |
-
|
| 13 |
-
|
| 14 |
-
|
| 15 |
-
|
| 16 |
-
|
| 17 |
-
|
| 18 |
-
|
| 19 |
-
|
| 20 |
-
|
| 21 |
-
|
| 22 |
-
|
| 23 |
-
|
| 24 |
-
|
| 25 |
-
|
| 26 |
-
|
| 27 |
-
|
| 28 |
-
|
| 29 |
-
|
| 30 |
-
|
| 31 |
-
|
| 32 |
-
|
| 33 |
-
|
| 34 |
-
|
| 35 |
-
|
| 36 |
-
|
| 37 |
-
|
| 38 |
-
|
| 39 |
-
""
|
| 40 |
-
|
| 41 |
-
|
| 42 |
-
|
| 43 |
-
|
| 44 |
-
|
| 45 |
-
""
|
| 46 |
-
|
| 47 |
-
|
| 48 |
-
|
| 49 |
-
|
| 50 |
-
|
| 51 |
-
|
| 52 |
-
|
| 53 |
-
|
| 54 |
-
|
| 55 |
-
|
| 56 |
-
|
| 57 |
-
|
| 58 |
-
|
| 59 |
-
|
| 60 |
-
|
| 61 |
-
|
| 62 |
-
|
| 63 |
-
|
| 64 |
-
|
| 65 |
-
|
| 66 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 67 |
) -> str:
|
| 68 |
-
"""
|
| 69 |
-
|
| 70 |
-
|
| 71 |
-
|
| 72 |
-
# Bước 1: Tiền xử lý
|
| 73 |
-
if phuong_phap_tien_xu_ly == "Tiền xử lý theo câu":
|
| 74 |
-
van_ban_da_xu_ly = tien_xu_ly_theo_cau(van_ban)
|
| 75 |
else:
|
| 76 |
-
|
| 77 |
|
| 78 |
-
#
|
| 79 |
-
if
|
| 80 |
-
|
| 81 |
-
elif
|
| 82 |
-
|
| 83 |
-
else:
|
| 84 |
-
|
| 85 |
|
| 86 |
-
return
|
| 87 |
|
| 88 |
|
| 89 |
-
# ======================
|
| 90 |
with gr.Blocks(
|
| 91 |
-
title="
|
| 92 |
theme=gr.themes.Soft(),
|
| 93 |
css="""
|
| 94 |
.gradio-container {max-width: 1200px; margin: auto;}
|
|
@@ -97,88 +135,88 @@ with gr.Blocks(
|
|
| 97 |
) as demo:
|
| 98 |
gr.Markdown(
|
| 99 |
"""
|
| 100 |
-
# 🚀
|
| 101 |
-
**
|
| 102 |
"""
|
| 103 |
)
|
| 104 |
|
| 105 |
with gr.Row():
|
| 106 |
with gr.Column(scale=3):
|
| 107 |
input_text = gr.Textbox(
|
| 108 |
-
label="📝
|
| 109 |
-
placeholder="
|
| 110 |
lines=12,
|
| 111 |
max_lines=30,
|
| 112 |
show_copy_button=True
|
| 113 |
)
|
| 114 |
|
| 115 |
with gr.Column(scale=1):
|
| 116 |
-
gr.Markdown("### ⚙️
|
| 117 |
|
| 118 |
-
|
| 119 |
choices=[
|
| 120 |
-
"
|
| 121 |
-
"
|
| 122 |
],
|
| 123 |
-
value="
|
| 124 |
-
label="
|
| 125 |
-
info="
|
| 126 |
)
|
| 127 |
|
| 128 |
-
|
| 129 |
choices=[
|
| 130 |
-
"
|
| 131 |
-
"
|
| 132 |
-
"
|
| 133 |
],
|
| 134 |
-
value="
|
| 135 |
-
label="
|
| 136 |
-
info="
|
| 137 |
)
|
| 138 |
|
| 139 |
with gr.Row():
|
| 140 |
btn_tom_tat = gr.Button(
|
| 141 |
-
"🔍
|
| 142 |
variant="primary",
|
| 143 |
size="large"
|
| 144 |
)
|
| 145 |
|
| 146 |
output_text = gr.Textbox(
|
| 147 |
-
label="📄
|
| 148 |
lines=10,
|
| 149 |
-
placeholder="
|
| 150 |
show_copy_button=True
|
| 151 |
)
|
| 152 |
|
| 153 |
-
#
|
| 154 |
btn_tom_tat.click(
|
| 155 |
-
fn=
|
| 156 |
-
inputs=[input_text,
|
| 157 |
outputs=output_text
|
| 158 |
)
|
| 159 |
|
| 160 |
-
#
|
| 161 |
gr.Examples(
|
| 162 |
examples=[
|
| 163 |
-
["
|
| 164 |
-
["
|
| 165 |
],
|
| 166 |
inputs=input_text,
|
| 167 |
-
label="📌
|
| 168 |
)
|
| 169 |
|
| 170 |
gr.Markdown(
|
| 171 |
"""
|
| 172 |
---
|
| 173 |
-
💡 **
|
| 174 |
-
1.
|
| 175 |
-
2.
|
| 176 |
-
3.
|
| 177 |
-
|
| 178 |
"""
|
| 179 |
)
|
| 180 |
|
| 181 |
|
| 182 |
-
#
|
| 183 |
if __name__ == "__main__":
|
| 184 |
demo.launch()
|
|
|
|
| 1 |
import gradio as gr
|
| 2 |
+
import torch
|
| 3 |
+
from transformers import BartForConditionalGeneration, BartTokenizer
|
| 4 |
+
import re
|
| 5 |
+
import numpy as np
|
| 6 |
+
import networkx as nx
|
| 7 |
+
from typing import List, Dict
|
| 8 |
+
from src.utils.get_model import get_summarizer
|
| 9 |
+
from src.preprocessing.edu_sentences import preprocess_external_text
|
| 10 |
+
from src.utils.get_model import get_extractive_model
|
| 11 |
+
from src.model.baseline_extractive_model import get_trigrams
|
| 12 |
+
|
| 13 |
+
REPO_ID_baseline_model = "Reality8081/bart-base"
|
| 14 |
+
REPO_ID_baseline_model_edu = "Reality8081/bart-base-edu"
|
| 15 |
+
REPO_ID_baseline_extractive_model = "Reality8081/bart-extractive"
|
| 16 |
+
REPO_ID_baseline_extractive_model_edu = "Reality8081/bart-extractive-edu"
|
| 17 |
+
|
| 18 |
+
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
| 19 |
+
|
| 20 |
+
def model_baseline(prepro_dict: Dict) -> str:
|
| 21 |
+
text_to_summarize = prepro_dict.get("article", "")
|
| 22 |
+
|
| 23 |
+
if not text_to_summarize.strip():
|
| 24 |
+
return "Lỗi: Không tìm thấy văn bản để tóm tắt."
|
| 25 |
+
|
| 26 |
+
segmentation_method = prepro_dict.get("segmentation_method")
|
| 27 |
+
if segmentation_method == "edu":
|
| 28 |
+
repo_id = REPO_ID_baseline_model_edu
|
| 29 |
+
else:
|
| 30 |
+
repo_id = REPO_ID_baseline_model
|
| 31 |
+
summarizer = get_summarizer(repo_id)
|
| 32 |
+
summary = summarizer.summarize(text_to_summarize)
|
| 33 |
+
return summary
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
def model_baseline_extractive(prepro_dict: Dict, top_n = 5) -> str:
|
| 38 |
+
tokenizer = BartTokenizer.from_pretrained("facebook/bart-large")
|
| 39 |
+
segments = prepro_dict["segments"]
|
| 40 |
+
if not segments:
|
| 41 |
+
return "Không thể phân tách văn bản thành các câu/EDU."
|
| 42 |
+
input_ids = torch.tensor([prepro_dict["input_ids"]]).to(device)
|
| 43 |
+
attention_mask = torch.tensor([prepro_dict["attention_mask"]]).to(device)
|
| 44 |
+
segmentation_method = prepro_dict.get("segmentation_method")
|
| 45 |
+
if segmentation_method == "edu":
|
| 46 |
+
repo_id = REPO_ID_baseline_extractive_model_edu
|
| 47 |
+
else:
|
| 48 |
+
repo_id = REPO_ID_baseline_extractive_model
|
| 49 |
+
model = get_summarizer(repo_id=repo_id, base_model_name="facebook/bart-large", device=device)
|
| 50 |
+
with torch.no_grad():
|
| 51 |
+
outputs = model(input_ids=input_ids, attention_mask=attention_mask)
|
| 52 |
+
# Sử dụng Sigmoid đưa logit về khoảng (0, 1) để lấy xác suất
|
| 53 |
+
probs = torch.sigmoid(outputs['logits']).squeeze(0).cpu().numpy()
|
| 54 |
+
|
| 55 |
+
segment_scores = []
|
| 56 |
+
current_idx = 1 # Bỏ qua token đặc biệt <s> ở đầu chuỗi
|
| 57 |
+
|
| 58 |
+
for seg in segments:
|
| 59 |
+
# Lấy số lượng token của đoạn hiện tại
|
| 60 |
+
seg_len = len(tokenizer.encode(seg, add_special_tokens=False))
|
| 61 |
+
end_idx = min(current_idx + seg_len, len(probs))
|
| 62 |
+
|
| 63 |
+
if current_idx < len(probs):
|
| 64 |
+
# Điểm của câu là trung bình cộng xác suất các token bên trong nó
|
| 65 |
+
seg_score = np.mean(probs[current_idx:end_idx])
|
| 66 |
+
else:
|
| 67 |
+
seg_score = 0.0
|
| 68 |
+
|
| 69 |
+
segment_scores.append(seg_score)
|
| 70 |
+
current_idx += seg_len
|
| 71 |
+
|
| 72 |
+
ranked_indices = np.argsort(segment_scores)[::-1] # Sắp xếp giảm dần theo điểm
|
| 73 |
+
|
| 74 |
+
selected_indices = []
|
| 75 |
+
selected_trigrams = set()
|
| 76 |
+
|
| 77 |
+
for idx in ranked_indices:
|
| 78 |
+
candidate_seg = segments[idx]
|
| 79 |
+
candidate_trigrams = get_trigrams(candidate_seg)
|
| 80 |
+
|
| 81 |
+
# Chỉ chọn nếu segment này không trùng lặp các cụm từ (trigram) so với những đoạn đã chọn
|
| 82 |
+
if not candidate_trigrams.intersection(selected_trigrams):
|
| 83 |
+
selected_indices.append(idx)
|
| 84 |
+
selected_trigrams.update(candidate_trigrams)
|
| 85 |
+
|
| 86 |
+
# Dừng lại nếu đã gom đủ số câu theo yêu cầu
|
| 87 |
+
if len(selected_indices) == top_n:
|
| 88 |
+
break
|
| 89 |
+
|
| 90 |
+
# Sắp xếp lại thứ tự index xuất hiện của các câu trong văn bản gốc để tóm tắt được mạch lạc
|
| 91 |
+
selected_indices = sorted(selected_indices)
|
| 92 |
+
|
| 93 |
+
# Lắp ráp kết quả
|
| 94 |
+
extractive_summary = " ".join([segments[i] for i in selected_indices if i < len(segments)])
|
| 95 |
+
return extractive_summary
|
| 96 |
+
|
| 97 |
+
def model_extractive_abstract(prepro_dict: Dict) -> str:
|
| 98 |
+
|
| 99 |
+
return prepro_dict
|
| 100 |
+
|
| 101 |
+
|
| 102 |
+
# ====================== MAIN FUNCTION ======================
|
| 103 |
+
def ATS(
|
| 104 |
+
text: str,
|
| 105 |
+
segmentation_method: str = None,
|
| 106 |
+
model: str = None,
|
| 107 |
+
reference_summary: str = None
|
| 108 |
) -> str:
|
| 109 |
+
"""Main workflow: Raw Text → Preprocessing → Model"""
|
| 110 |
+
# Step 1: Preprocessing
|
| 111 |
+
if segmentation_method == "Sentence-based Preprocessing":
|
| 112 |
+
prepro_dict = preprocess_external_text(text, reference_summary, segmentation_method='sentence')
|
|
|
|
|
|
|
|
|
|
| 113 |
else:
|
| 114 |
+
prepro_dict = preprocess_external_text(text, reference_summary, segmentation_method='edu')
|
| 115 |
|
| 116 |
+
# Step 2: Chọn model
|
| 117 |
+
if model == "Baseline Model: TextRank + Vanilla BART":
|
| 118 |
+
result = model_baseline(prepro_dict)
|
| 119 |
+
elif model == "Baseline Model with Extractive":
|
| 120 |
+
result = model_baseline_extractive(prepro_dict)
|
| 121 |
+
else:
|
| 122 |
+
result = model_extractive_abstract(prepro_dict)
|
| 123 |
|
| 124 |
+
return result
|
| 125 |
|
| 126 |
|
| 127 |
+
# ====================== GRADIO INTERFACE ======================
|
| 128 |
with gr.Blocks(
|
| 129 |
+
title="Automated Text Summarization System",
|
| 130 |
theme=gr.themes.Soft(),
|
| 131 |
css="""
|
| 132 |
.gradio-container {max-width: 1200px; margin: auto;}
|
|
|
|
| 135 |
) as demo:
|
| 136 |
gr.Markdown(
|
| 137 |
"""
|
| 138 |
+
# 🚀 Automated Text Summarization System
|
| 139 |
+
**Input text → Select method & model → Get results instantly**
|
| 140 |
"""
|
| 141 |
)
|
| 142 |
|
| 143 |
with gr.Row():
|
| 144 |
with gr.Column(scale=3):
|
| 145 |
input_text = gr.Textbox(
|
| 146 |
+
label="📝 Text to Summarize",
|
| 147 |
+
placeholder="Paste your long text here (up to several thousand words)...",
|
| 148 |
lines=12,
|
| 149 |
max_lines=30,
|
| 150 |
show_copy_button=True
|
| 151 |
)
|
| 152 |
|
| 153 |
with gr.Column(scale=1):
|
| 154 |
+
gr.Markdown("### ⚙️ Settings")
|
| 155 |
|
| 156 |
+
method = gr.Radio(
|
| 157 |
choices=[
|
| 158 |
+
"Sentence-based Preprocessing",
|
| 159 |
+
"Semantic-based Preprocessing (EDU)"
|
| 160 |
],
|
| 161 |
+
value="Sentence-based Preprocessing",
|
| 162 |
+
label="Preprocessing Method",
|
| 163 |
+
info="Choose how to clean the text before summarization"
|
| 164 |
)
|
| 165 |
|
| 166 |
+
model = gr.Radio(
|
| 167 |
choices=[
|
| 168 |
+
"Baseline Model",
|
| 169 |
+
"Baseline Model with Extractive",
|
| 170 |
+
"Extractive and Abstractive Model"
|
| 171 |
],
|
| 172 |
+
value="Baseline Model",
|
| 173 |
+
label="Summarization Model",
|
| 174 |
+
info="Select the model you want to use"
|
| 175 |
)
|
| 176 |
|
| 177 |
with gr.Row():
|
| 178 |
btn_tom_tat = gr.Button(
|
| 179 |
+
"🔍 Summarize Now",
|
| 180 |
variant="primary",
|
| 181 |
size="large"
|
| 182 |
)
|
| 183 |
|
| 184 |
output_text = gr.Textbox(
|
| 185 |
+
label="📄 Summary Result",
|
| 186 |
lines=10,
|
| 187 |
+
placeholder="The result will appear here...",
|
| 188 |
show_copy_button=True
|
| 189 |
)
|
| 190 |
|
| 191 |
+
# Connect button click
|
| 192 |
btn_tom_tat.click(
|
| 193 |
+
fn=ATS,
|
| 194 |
+
inputs=[input_text, method, model],
|
| 195 |
outputs=output_text
|
| 196 |
)
|
| 197 |
|
| 198 |
+
# Examples
|
| 199 |
gr.Examples(
|
| 200 |
examples=[
|
| 201 |
+
["Hanoi is the capital of Vietnam. The city has a history of over 1000 years. The population is about 8 million people. It is the political, economic, and cultural center of the country."],
|
| 202 |
+
["Artificial Intelligence (AI) is changing the world. Many large companies are investing heavily in this field. Gradio helps developers quickly build web interfaces for AI models."],
|
| 203 |
],
|
| 204 |
inputs=input_text,
|
| 205 |
+
label="📌 Test Examples"
|
| 206 |
)
|
| 207 |
|
| 208 |
gr.Markdown(
|
| 209 |
"""
|
| 210 |
---
|
| 211 |
+
💡 **User Guide**:
|
| 212 |
+
1. Paste the text to be summarized into the left box.
|
| 213 |
+
2. Select the preprocessing method and model.
|
| 214 |
+
3. Click **Summarize Now**.
|
| 215 |
+
The result will appear instantly (currently using mock data for demo).
|
| 216 |
"""
|
| 217 |
)
|
| 218 |
|
| 219 |
|
| 220 |
+
# Launch the app
|
| 221 |
if __name__ == "__main__":
|
| 222 |
demo.launch()
|
src/model/baseline_extractive_model.py
ADDED
|
@@ -0,0 +1,45 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torch.nn as nn
|
| 3 |
+
import numpy as np
|
| 4 |
+
from transformers import BartModel, BartTokenizer
|
| 5 |
+
|
| 6 |
+
class BartExtractiveSummarizer(nn.Module):
|
| 7 |
+
def __init__(self, model_name="facebook/bart-large"):
|
| 8 |
+
super(BartExtractiveSummarizer, self).__init__()
|
| 9 |
+
self.encoder = BartModel.from_pretrained(model_name).encoder
|
| 10 |
+
hidden_size = self.encoder.config.hidden_size
|
| 11 |
+
self.classifier = nn.Linear(hidden_size, 1)
|
| 12 |
+
|
| 13 |
+
def forward(self, input_ids, attention_mask, saliency_mask=None, **kwargs):
|
| 14 |
+
encoder_outputs = self.encoder(
|
| 15 |
+
input_ids=input_ids,
|
| 16 |
+
attention_mask=attention_mask
|
| 17 |
+
)
|
| 18 |
+
hidden_states = encoder_outputs.last_hidden_state
|
| 19 |
+
logits = self.classifier(hidden_states).squeeze(-1)
|
| 20 |
+
|
| 21 |
+
loss = None
|
| 22 |
+
if saliency_mask is not None:
|
| 23 |
+
active_loss = attention_mask.view(-1) == 1
|
| 24 |
+
active_logits = logits.view(-1)[active_loss]
|
| 25 |
+
active_labels = saliency_mask.view(-1)[active_loss].float()
|
| 26 |
+
|
| 27 |
+
# --- TỐI ƯU: TỰ ĐỘNG TÍNH CLASS WEIGHT CHO TỪNG BATCH ---
|
| 28 |
+
num_pos = active_labels.sum()
|
| 29 |
+
num_neg = active_labels.size(0) - num_pos
|
| 30 |
+
|
| 31 |
+
if num_pos > 0:
|
| 32 |
+
weight = torch.clamp(num_neg / num_pos, max=10.0)
|
| 33 |
+
else:
|
| 34 |
+
weight = torch.tensor(1.0).to(logits.device)
|
| 35 |
+
|
| 36 |
+
loss_fct = nn.BCEWithLogitsLoss(pos_weight=weight)
|
| 37 |
+
loss = loss_fct(active_logits, active_labels)
|
| 38 |
+
|
| 39 |
+
return {"loss": loss, "logits": logits} if loss is not None else {"logits": logits}
|
| 40 |
+
|
| 41 |
+
def get_trigrams(text: str):
|
| 42 |
+
"""Tạo tập hợp các cụm 3 từ liên tiếp từ một đoạn văn bản (Trigram Blocking)"""
|
| 43 |
+
words = text.lower().split()
|
| 44 |
+
return set(tuple(words[i:i+3]) for i in range(len(words)-2))
|
| 45 |
+
|
src/model/baseline_model.py
ADDED
|
@@ -0,0 +1,87 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import networkx as nx
|
| 2 |
+
from sklearn.feature_extraction.text import TfidfVectorizer
|
| 3 |
+
from sklearn.metrics.pairwise import cosine_similarity
|
| 4 |
+
import torch
|
| 5 |
+
from transformers import BartForConditionalGeneration, BartTokenizer
|
| 6 |
+
|
| 7 |
+
def textrank_summarize(sentences, top_n=3):
|
| 8 |
+
"""
|
| 9 |
+
Trích xuất câu quan trọng bằng TextRank + TF-IDF.
|
| 10 |
+
Đầu vào 'sentences' là một list các câu (hoặc EDUs) trong một văn bản.
|
| 11 |
+
"""
|
| 12 |
+
# Xử lý trường hợp bài báo quá ngắn
|
| 13 |
+
if len(sentences) <= top_n:
|
| 14 |
+
return " ".join(sentences)
|
| 15 |
+
|
| 16 |
+
try:
|
| 17 |
+
# Bước 1: Khởi tạo TfidfVectorizer và fit_transform tập sentences của 1 bài báo
|
| 18 |
+
vectorizer = TfidfVectorizer(stop_words='english')
|
| 19 |
+
tfidf_matrix = vectorizer.fit_transform(sentences)
|
| 20 |
+
|
| 21 |
+
# Bước 2: Tính ma trận tương đồng Cosine
|
| 22 |
+
similarity_matrix = cosine_similarity(tfidf_matrix, tfidf_matrix)
|
| 23 |
+
|
| 24 |
+
# Bước 3: Đưa ma trận vào networkx tạo đồ thị và tính PageRank
|
| 25 |
+
nx_graph = nx.from_numpy_array(similarity_matrix)
|
| 26 |
+
scores = nx.pagerank(nx_graph)
|
| 27 |
+
|
| 28 |
+
# Bước 4: Sắp xếp điểm số và chọn top_n câu
|
| 29 |
+
ranked_sentences = sorted(((scores[i], s) for i, s in enumerate(sentences)), reverse=True)
|
| 30 |
+
|
| 31 |
+
# Giữ đúng thứ tự xuất hiện của câu trong văn bản gốc để dễ đọc
|
| 32 |
+
top_sentences_indices = sorted([sentences.index(ranked_sentences[i][1]) for i in range(top_n)])
|
| 33 |
+
summary = " ".join([sentences[i] for i in top_sentences_indices])
|
| 34 |
+
return summary
|
| 35 |
+
|
| 36 |
+
except Exception as e:
|
| 37 |
+
# Fallback về Lead-N nếu đồ thị lỗi (do câu rỗng hoặc không có từ vựng)
|
| 38 |
+
return " ".join(sentences[:top_n])
|
| 39 |
+
|
| 40 |
+
|
| 41 |
+
class BartSummarizer:
|
| 42 |
+
def __init__(self, model_path="facebook/bart-base"):
|
| 43 |
+
"""
|
| 44 |
+
Khởi tạo mô hình và tokenizer.
|
| 45 |
+
model_path có thể là repo trên Hugging Face hoặc đường dẫn local chứa weights.
|
| 46 |
+
"""
|
| 47 |
+
self.device = "cuda" if torch.cuda.is_available() else "cpu"
|
| 48 |
+
print(f"Loading BART model from '{model_path}' onto {self.device}...")
|
| 49 |
+
|
| 50 |
+
self.tokenizer = BartTokenizer.from_pretrained(model_path)
|
| 51 |
+
self.model = BartForConditionalGeneration.from_pretrained(model_path)
|
| 52 |
+
self.model.to(self.device)
|
| 53 |
+
self.model.eval() # Chuyển sang chế độ inference ngay từ đầu
|
| 54 |
+
|
| 55 |
+
def summarize(self, text, max_input_length=512, max_output_length=128, min_output_length=30):
|
| 56 |
+
"""
|
| 57 |
+
Hàm sinh tóm tắt cho một đoạn văn bản đầu vào.
|
| 58 |
+
"""
|
| 59 |
+
with torch.no_grad():
|
| 60 |
+
# Cắt ngắn đầu vào để chống quá tải GPU, đồng bộ với lúc train
|
| 61 |
+
inputs = self.tokenizer(
|
| 62 |
+
text,
|
| 63 |
+
max_length=max_input_length,
|
| 64 |
+
truncation=True,
|
| 65 |
+
padding=True,
|
| 66 |
+
return_tensors="pt"
|
| 67 |
+
).to(self.device)
|
| 68 |
+
|
| 69 |
+
# Sinh văn bản tóm tắt
|
| 70 |
+
summary_ids = self.model.generate(
|
| 71 |
+
input_ids=inputs["input_ids"],
|
| 72 |
+
attention_mask=inputs["attention_mask"],
|
| 73 |
+
max_length=max_output_length,
|
| 74 |
+
min_length=min_output_length,
|
| 75 |
+
num_beams=4,
|
| 76 |
+
length_penalty=2.0, # Ưu tiên sinh câu trọn vẹn
|
| 77 |
+
no_repeat_ngram_size=3, # Chống ảo giác, lặp từ
|
| 78 |
+
early_stopping=True
|
| 79 |
+
)
|
| 80 |
+
|
| 81 |
+
# Decode kết quả về dạng text
|
| 82 |
+
summary = self.tokenizer.decode(summary_ids[0], skip_special_tokens=True)
|
| 83 |
+
return summary
|
| 84 |
+
|
| 85 |
+
# Cách gọi trong app:
|
| 86 |
+
# summarizer = BartSummarizer("duong_dan_model_cua_ban_tren_huggingface")
|
| 87 |
+
# result = summarizer.summarize("Đoạn văn bản cần tóm tắt...")
|
src/preprocessing/edu_sentences.py
ADDED
|
@@ -0,0 +1,178 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# ========================== preprocessing_utils.py ==========================
|
| 2 |
+
import re
|
| 3 |
+
import nltk
|
| 4 |
+
import numpy as np
|
| 5 |
+
import spacy
|
| 6 |
+
from transformers import BartTokenizer
|
| 7 |
+
from rouge_score import rouge_scorer
|
| 8 |
+
from typing import List, Dict, Optional, Union
|
| 9 |
+
|
| 10 |
+
nltk.download('punkt', quiet=True)
|
| 11 |
+
nltk.download('punkt_tab', quiet=True)
|
| 12 |
+
|
| 13 |
+
# Tải SpaCy một lần duy nhất (nhẹ, disable các thành phần không cần)
|
| 14 |
+
nlp = spacy.load("en_core_web_sm", disable=["ner", "lemmatizer", "attribute_ruler", "tok2vec"])
|
| 15 |
+
|
| 16 |
+
tokenizer = BartTokenizer.from_pretrained("facebook/bart-large-cnn")
|
| 17 |
+
scorer = rouge_scorer.RougeScorer(['rouge1', 'rouge2', 'rougeL'], use_stemmer=True)
|
| 18 |
+
|
| 19 |
+
def clean_text(text: str) -> str:
|
| 20 |
+
"""Làm sạch văn bản (dùng chung cho mọi pipeline)"""
|
| 21 |
+
if not isinstance(text, str):
|
| 22 |
+
return ""
|
| 23 |
+
# Xóa URL, email, twitter handle
|
| 24 |
+
text = re.sub(r'http\S+|www\S+|https\S+', '', text, flags=re.MULTILINE)
|
| 25 |
+
text = re.sub(r'\S+@\S+', '', text)
|
| 26 |
+
text = re.sub(r'@[A-Za-z0-9_]+', '', text)
|
| 27 |
+
# Giữ lại chữ, số, dấu câu cơ bản
|
| 28 |
+
text = re.sub(r'[^\w\s.,;:\'"-?!]', '', text)
|
| 29 |
+
# Chuẩn hóa khoảng trắng
|
| 30 |
+
text = re.sub(r'\s+', ' ', text).strip()
|
| 31 |
+
return text
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
def segment_text(text: str, method: str = 'sentence') -> tuple[List[str], str]:
|
| 35 |
+
"""
|
| 36 |
+
Phân tách văn bản theo phương pháp được chọn.
|
| 37 |
+
Trả về: (list_segments, cleaned_text)
|
| 38 |
+
"""
|
| 39 |
+
cleaned = clean_text(text)
|
| 40 |
+
|
| 41 |
+
if method == 'sentence':
|
| 42 |
+
segments = nltk.sent_tokenize(cleaned)
|
| 43 |
+
return segments, cleaned
|
| 44 |
+
|
| 45 |
+
elif method == 'edu':
|
| 46 |
+
# Giống hệt logic notebook EDU (tách câu trước → EDU bằng SpaCy)
|
| 47 |
+
sentences = nltk.sent_tokenize(cleaned)
|
| 48 |
+
processed_docs = list(nlp.pipe(sentences, batch_size=500))
|
| 49 |
+
|
| 50 |
+
all_edus = []
|
| 51 |
+
for doc in processed_docs:
|
| 52 |
+
temp_edus, current_segment = [], []
|
| 53 |
+
for token in doc:
|
| 54 |
+
current_segment.append(token.text_with_ws)
|
| 55 |
+
if (token.pos_ in ["SCONJ", "CCONJ"] or token.text in [",", ";"]) and len(current_segment) > 3:
|
| 56 |
+
temp_edus.append("".join(current_segment).strip())
|
| 57 |
+
current_segment = []
|
| 58 |
+
if current_segment:
|
| 59 |
+
temp_edus.append("".join(current_segment).strip())
|
| 60 |
+
all_edus.extend(temp_edus if temp_edus else [doc.text])
|
| 61 |
+
|
| 62 |
+
return all_edus, cleaned
|
| 63 |
+
|
| 64 |
+
else:
|
| 65 |
+
raise ValueError("method phải là 'sentence' hoặc 'edu'")
|
| 66 |
+
|
| 67 |
+
|
| 68 |
+
def greedy_rouge_selection(segments: List[str], reference_summary: str, top_k: int = 3) -> List[int]:
|
| 69 |
+
"""Thuật toán Greedy ROUGE (dùng chung)"""
|
| 70 |
+
selected_indices = []
|
| 71 |
+
best_rouge = 0.0
|
| 72 |
+
if not segments:
|
| 73 |
+
return []
|
| 74 |
+
|
| 75 |
+
for _ in range(min(top_k, len(segments))):
|
| 76 |
+
best_idx = -1
|
| 77 |
+
current_best = best_rouge
|
| 78 |
+
|
| 79 |
+
for i, seg in enumerate(segments):
|
| 80 |
+
if i in selected_indices:
|
| 81 |
+
continue
|
| 82 |
+
candidate = " ".join([segments[j] for j in selected_indices] + [seg])
|
| 83 |
+
scores = scorer.score(reference_summary, candidate)
|
| 84 |
+
avg_f = (scores['rouge1'].fmeasure +
|
| 85 |
+
scores['rouge2'].fmeasure +
|
| 86 |
+
scores['rougeL'].fmeasure) / 3.0
|
| 87 |
+
|
| 88 |
+
if avg_f > current_best:
|
| 89 |
+
current_best = avg_f
|
| 90 |
+
best_idx = i
|
| 91 |
+
|
| 92 |
+
if best_idx != -1:
|
| 93 |
+
selected_indices.append(best_idx)
|
| 94 |
+
best_rouge = current_best
|
| 95 |
+
else:
|
| 96 |
+
break
|
| 97 |
+
|
| 98 |
+
return [1 if i in selected_indices else 0 for i in range(len(segments))]
|
| 99 |
+
|
| 100 |
+
|
| 101 |
+
def create_saliency_mask(input_ids: List[int], segments: List[str],
|
| 102 |
+
ext_labels: List[int], tokenizer) -> List[int]:
|
| 103 |
+
"""Tạo Saliency Mask từ segment-level xuống token-level"""
|
| 104 |
+
mask = np.zeros(len(input_ids), dtype=int)
|
| 105 |
+
mask[0] = 1
|
| 106 |
+
if input_ids and input_ids[-1] == tokenizer.eos_token_id:
|
| 107 |
+
mask[-1] = 1
|
| 108 |
+
|
| 109 |
+
current_idx = 1
|
| 110 |
+
for seg_idx, segment in enumerate(segments):
|
| 111 |
+
if current_idx >= len(input_ids) - 1:
|
| 112 |
+
break
|
| 113 |
+
seg_tokens = tokenizer.encode(segment, add_special_tokens=False)
|
| 114 |
+
token_len = len(seg_tokens)
|
| 115 |
+
|
| 116 |
+
if seg_idx < len(ext_labels) and ext_labels[seg_idx] == 1:
|
| 117 |
+
end_idx = min(current_idx + token_len, len(input_ids) - 1)
|
| 118 |
+
mask[current_idx:end_idx] = 1
|
| 119 |
+
|
| 120 |
+
current_idx += token_len
|
| 121 |
+
|
| 122 |
+
return mask.tolist()
|
| 123 |
+
|
| 124 |
+
|
| 125 |
+
def preprocess_external_text(
|
| 126 |
+
text: str,
|
| 127 |
+
reference_summary: Optional[str] = None,
|
| 128 |
+
segmentation_method: str = 'sentence',
|
| 129 |
+
top_k: int = 3,
|
| 130 |
+
max_length: int = 1024
|
| 131 |
+
) -> Dict:
|
| 132 |
+
|
| 133 |
+
segments, cleaned_article = segment_text(text, method=segmentation_method)
|
| 134 |
+
|
| 135 |
+
inputs = tokenizer(cleaned_article, max_length=max_length, truncation=True, padding=False)
|
| 136 |
+
|
| 137 |
+
result = {
|
| 138 |
+
"article": cleaned_article,
|
| 139 |
+
"segments": segments, # ← list câu hoặc list EDU
|
| 140 |
+
"segmentation_method": segmentation_method,
|
| 141 |
+
"input_ids": inputs["input_ids"],
|
| 142 |
+
"attention_mask": inputs["attention_mask"],
|
| 143 |
+
}
|
| 144 |
+
|
| 145 |
+
# Nếu có tóm tắt tham chiếu → tính nhãn extractive
|
| 146 |
+
if reference_summary is not None:
|
| 147 |
+
ref_clean = clean_text(reference_summary)
|
| 148 |
+
extractive_labels = greedy_rouge_selection(segments, ref_clean, top_k=top_k)
|
| 149 |
+
saliency_mask = create_saliency_mask(inputs["input_ids"], segments, extractive_labels, tokenizer)
|
| 150 |
+
|
| 151 |
+
targets = tokenizer(ref_clean, max_length=128, truncation=True, padding=False)
|
| 152 |
+
|
| 153 |
+
result.update({
|
| 154 |
+
"extractive_labels": extractive_labels,
|
| 155 |
+
"saliency_mask": saliency_mask,
|
| 156 |
+
"labels": targets["input_ids"], # cho phần Abstractive
|
| 157 |
+
"reference_summary": ref_clean
|
| 158 |
+
})
|
| 159 |
+
|
| 160 |
+
return result
|
| 161 |
+
|
| 162 |
+
|
| 163 |
+
def preprocess_batch(
|
| 164 |
+
texts: List[str],
|
| 165 |
+
reference_summaries: Optional[List[str]] = None,
|
| 166 |
+
segmentation_method: str = 'sentence',
|
| 167 |
+
top_k: int = 3
|
| 168 |
+
) -> List[Dict]:
|
| 169 |
+
"""Xử lý nhiều văn bản cùng lúc (dùng cho demo batch)"""
|
| 170 |
+
if reference_summaries is None:
|
| 171 |
+
reference_summaries = [None] * len(texts)
|
| 172 |
+
if len(reference_summaries) != len(texts):
|
| 173 |
+
raise ValueError("Số lượng reference_summaries phải bằng số lượng texts")
|
| 174 |
+
|
| 175 |
+
return [
|
| 176 |
+
preprocess_external_text(txt, ref, segmentation_method, top_k)
|
| 177 |
+
for txt, ref in zip(texts, reference_summaries)
|
| 178 |
+
]
|
src/utils/get_model.py
ADDED
|
@@ -0,0 +1,35 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from src.model.baseline_model import BartSummarizer
|
| 2 |
+
from src.model.baseline_extractive_model import BartExtractiveSummarizer
|
| 3 |
+
loaded_summarizers = {}
|
| 4 |
+
import torch
|
| 5 |
+
from huggingface_hub import hf_hub_download
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
def get_summarizer(repo_id: str):
|
| 11 |
+
|
| 12 |
+
if repo_id not in loaded_summarizers:
|
| 13 |
+
loaded_summarizers[repo_id] = BartSummarizer(model_path=repo_id)
|
| 14 |
+
return loaded_summarizers[repo_id]
|
| 15 |
+
|
| 16 |
+
def get_extractive_model(repo_id: str, base_model_name: str = "facebook/bart-large", device: torch.device = "cpu"):
|
| 17 |
+
"""Tải và lưu cache mô hình Custom Extractive từ Hugging Face Hub"""
|
| 18 |
+
if repo_id not in loaded_summarizers:
|
| 19 |
+
print(f"Đang tải mô hình Extractive từ repo: {repo_id}...")
|
| 20 |
+
|
| 21 |
+
# Khởi tạo khung kiến trúc trống
|
| 22 |
+
model = BartExtractiveSummarizer(model_name=base_model_name)
|
| 23 |
+
|
| 24 |
+
# Sử dụng hf_hub_download để kéo file trọng số về local cache
|
| 25 |
+
model_path = hf_hub_download(repo_id=repo_id, filename="model_state.bin")
|
| 26 |
+
|
| 27 |
+
# Load trọng số vào model
|
| 28 |
+
model.load_state_dict(torch.load(model_path, map_location=device))
|
| 29 |
+
|
| 30 |
+
model.to(device)
|
| 31 |
+
model.eval() # Chuyển mô hình sang chế độ inference
|
| 32 |
+
|
| 33 |
+
loaded_summarizers[repo_id] = model
|
| 34 |
+
|
| 35 |
+
return loaded_summarizers[repo_id]
|