Reality8081 commited on
Commit
fdcfd15
·
1 Parent(s): 846a1b0

Update SRC basline and baseline_extractive

Browse files
app.py CHANGED
@@ -1,94 +1,132 @@
1
  import gradio as gr
2
-
3
- # ====================== TIỀN XỬ LÝ ======================
4
- def tien_xu_ly_theo_cau(van_ban: str) -> str:
5
- """
6
- Hàm tiền xử lý theo câu (Skeleton function).
7
- Bạn thể thay thế bằng code thật (ví dụ: tách câu bằng NLTK, spaCy, hoặc regex).
8
- Hiện tại mock: giữ nguyên văn bản gốc để demo.
9
- """
10
- # TODO: Thay code thật ở đây
11
- # dụ:
12
- # import nltk
13
- # nltk.download('punkt', quiet=True)
14
- # cau = nltk.sent_tokenize(van_ban)
15
- # return " ".join(cau[:10]) # giữ 10 câu đầu làm ví dụ
16
- return van_ban
17
-
18
-
19
- def tien_xu_ly_theo_ngu_nghia(van_ban: str) -> str:
20
- """
21
- Hàm tiền xử lý theo ngữ nghĩa (Skeleton function).
22
- Bạn có thể thay thế bằng code thật (ví dụ: cleaning, embedding, loại bỏ stopword, v.v.).
23
- Hiện tại là mock: giữ nguyên văn bản gốc để demo.
24
- """
25
- # TODO: Thay code thật ở đây
26
- # dụ: dùng transformers hoặc spaCy để embedding hoặc semantic cleaning
27
- return van_ban
28
-
29
-
30
- # ====================== MÔ HÌNH TÓM TẮT ======================
31
- def mo_hinh_baseline(van_ban_da_xu_ly: str) -> str:
32
- """
33
- hình tóm tắt Baseline - PLACEHOLDER.
34
- # TODO: Thay thế bằng mô hình thật (ví dụ: transformers pipeline)
35
- # Ví dụ code thật:
36
- # from transformers import pipeline
37
- # summarizer = pipeline("summarization", model="facebook/bart-large-cnn")
38
- # return summarizer(van_ban_da_xu_ly, max_length=130, min_length=30, do_sample=False)[0]['summary_text']
39
- """
40
- # Mock để demo
41
- return f"🔹 TÓM TẮT BASELINE:\n{van_ban_da_xu_ly[:250]}..."
42
-
43
-
44
- def mo_hinh_baseline_extractive(van_ban_da_xu_ly: str) -> str:
45
- """
46
- hình Baseline + Extractive - PLACEHOLDER.
47
- # TODO: Thay thế bằng code thật (extractive trước rồi abstractive)
48
- """
49
- # Mock để demo
50
- return f"🔹 TÓM TẮT BASELINE + EXTRACTIVE:\n{van_ban_da_xu_ly[:220]}..."
51
-
52
-
53
- def mo_hinh_extractive_abstract(van_ban_da_xu_ly: str) -> str:
54
- """
55
- hình Extractive + Abstract - PLACEHOLDER.
56
- # TODO: Thay thế bằng code thật (thường dùng 2 bước: extractive → abstractive)
57
- """
58
- # Mock để demo
59
- return f"🔹 TÓM TẮT EXTRACTIVE + ABSTRACT:\n{van_ban_da_xu_ly[:200]}..."
60
-
61
-
62
- # ====================== HÀM CHÍNH ======================
63
- def tom_tat_van_ban(
64
- van_ban: str,
65
- phuong_phap_tien_xu_ly: str,
66
- mo_hinh: str
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
67
  ) -> str:
68
- """
69
- Hàm chính kết nối toàn bộ luồng:
70
- Văn bản gốc Tiền xử lý → Mô hình → Kết quả tóm tắt.
71
- """
72
- # Bước 1: Tiền xử lý
73
- if phuong_phap_tien_xu_ly == "Tiền xử lý theo câu":
74
- van_ban_da_xu_ly = tien_xu_ly_theo_cau(van_ban)
75
  else:
76
- van_ban_da_xu_ly = tien_xu_ly_theo_ngu_nghia(van_ban)
77
 
78
- # Bước 2: Chọn mô hình tóm tắt
79
- if mo_hinh == " hình baseline":
80
- ket_qua = mo_hinh_baseline(van_ban_da_xu_ly)
81
- elif mo_hinh == " hình baseline Extractive":
82
- ket_qua = mo_hinh_baseline_extractive(van_ban_da_xu_ly)
83
- else: # Mô hình Extractive và Abstract
84
- ket_qua = mo_hinh_extractive_abstract(van_ban_da_xu_ly)
85
 
86
- return ket_qua
87
 
88
 
89
- # ====================== GIAO DIỆN GRADIO ======================
90
  with gr.Blocks(
91
- title="Hệ thống Tóm tắt Văn bản Tự động",
92
  theme=gr.themes.Soft(),
93
  css="""
94
  .gradio-container {max-width: 1200px; margin: auto;}
@@ -97,88 +135,88 @@ with gr.Blocks(
97
  ) as demo:
98
  gr.Markdown(
99
  """
100
- # 🚀 Hệ thống Tóm tắt Văn bản Tự động
101
- **Nhập văn bản cần tóm tắt Chọn phương pháp & hình Nhận kết quả ngay lập tức**
102
  """
103
  )
104
 
105
  with gr.Row():
106
  with gr.Column(scale=3):
107
  input_text = gr.Textbox(
108
- label="📝 Văn bản cần tóm tắt",
109
- placeholder="Dán đoạn văn bản dài vào đây ( thể vài nghìn từ)...",
110
  lines=12,
111
  max_lines=30,
112
  show_copy_button=True
113
  )
114
 
115
  with gr.Column(scale=1):
116
- gr.Markdown("### ⚙️ Cài đặt")
117
 
118
- phuong_phap = gr.Radio(
119
  choices=[
120
- "Tiền xử lý theo câu",
121
- "Tiền xử lý theo ngữ nghĩa"
122
  ],
123
- value="Tiền xử lý theo câu",
124
- label="Phương pháp Tiền xử lý",
125
- info="Chọn cách làm sạch văn bản trước khi tóm tắt"
126
  )
127
 
128
- mo_hinh = gr.Radio(
129
  choices=[
130
- " hình baseline",
131
- " hình baseline Extractive",
132
- "Mô hình Extractive Abstract"
133
  ],
134
- value=" hình baseline",
135
- label=" hình Tóm tắt",
136
- info="Chọn hình bạn muốn sử dụng"
137
  )
138
 
139
  with gr.Row():
140
  btn_tom_tat = gr.Button(
141
- "🔍 Tóm tắt ngay",
142
  variant="primary",
143
  size="large"
144
  )
145
 
146
  output_text = gr.Textbox(
147
- label="📄 Kết quả tóm tắt",
148
  lines=10,
149
- placeholder="Kết quả sẽ hiển thị ở đây...",
150
  show_copy_button=True
151
  )
152
 
153
- # Kết nối nút bấm
154
  btn_tom_tat.click(
155
- fn=tom_tat_van_ban,
156
- inputs=[input_text, phuong_phap, mo_hinh],
157
  outputs=output_text
158
  )
159
 
160
- # Ví dụ minh họa
161
  gr.Examples(
162
  examples=[
163
- [" Nội thủ đô của Việt Nam. Thành phố này lịch sử hơn 1000 năm. Dân số khoảng 8 triệu người. Đây trung tâm chính trị, kinh tế văn hóa của cả nước."],
164
- ["Trí tuệ nhân tạo (AI) đang thay đổi thế giới. Nhiều công ty lớn đang đầu mạnh vào lĩnh vực này. Gradio giúp lập trình viên nhanh chóng xây dựng giao diện web cho các mô hình AI."],
165
  ],
166
  inputs=input_text,
167
- label="📌 dụ thử nghiệm"
168
  )
169
 
170
  gr.Markdown(
171
  """
172
  ---
173
- 💡 **Hướng dẫn sử dụng**:
174
- 1. Dán văn bản cần tóm tắt vào ô bên trái.
175
- 2. Chọn phương pháp tiền xử lý và mô hình.
176
- 3. Nhấn **Tóm tắt ngay**.
177
- Kết quả sẽ xuất hiện ngay lập tức (hiện đang dùng mock để demo).
178
  """
179
  )
180
 
181
 
182
- # Khởi chạy ứng dụng
183
  if __name__ == "__main__":
184
  demo.launch()
 
1
  import gradio as gr
2
+ import torch
3
+ from transformers import BartForConditionalGeneration, BartTokenizer
4
+ import re
5
+ import numpy as np
6
+ import networkx as nx
7
+ from typing import List, Dict
8
+ from src.utils.get_model import get_summarizer
9
+ from src.preprocessing.edu_sentences import preprocess_external_text
10
+ from src.utils.get_model import get_extractive_model
11
+ from src.model.baseline_extractive_model import get_trigrams
12
+
13
+ REPO_ID_baseline_model = "Reality8081/bart-base"
14
+ REPO_ID_baseline_model_edu = "Reality8081/bart-base-edu"
15
+ REPO_ID_baseline_extractive_model = "Reality8081/bart-extractive"
16
+ REPO_ID_baseline_extractive_model_edu = "Reality8081/bart-extractive-edu"
17
+
18
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
19
+
20
+ def model_baseline(prepro_dict: Dict) -> str:
21
+ text_to_summarize = prepro_dict.get("article", "")
22
+
23
+ if not text_to_summarize.strip():
24
+ return "Lỗi: Không tìm thấy văn bản để tóm tắt."
25
+
26
+ segmentation_method = prepro_dict.get("segmentation_method")
27
+ if segmentation_method == "edu":
28
+ repo_id = REPO_ID_baseline_model_edu
29
+ else:
30
+ repo_id = REPO_ID_baseline_model
31
+ summarizer = get_summarizer(repo_id)
32
+ summary = summarizer.summarize(text_to_summarize)
33
+ return summary
34
+
35
+
36
+
37
+ def model_baseline_extractive(prepro_dict: Dict, top_n = 5) -> str:
38
+ tokenizer = BartTokenizer.from_pretrained("facebook/bart-large")
39
+ segments = prepro_dict["segments"]
40
+ if not segments:
41
+ return "Không thể phân tách văn bản thành các câu/EDU."
42
+ input_ids = torch.tensor([prepro_dict["input_ids"]]).to(device)
43
+ attention_mask = torch.tensor([prepro_dict["attention_mask"]]).to(device)
44
+ segmentation_method = prepro_dict.get("segmentation_method")
45
+ if segmentation_method == "edu":
46
+ repo_id = REPO_ID_baseline_extractive_model_edu
47
+ else:
48
+ repo_id = REPO_ID_baseline_extractive_model
49
+ model = get_summarizer(repo_id=repo_id, base_model_name="facebook/bart-large", device=device)
50
+ with torch.no_grad():
51
+ outputs = model(input_ids=input_ids, attention_mask=attention_mask)
52
+ # Sử dụng Sigmoid đưa logit về khoảng (0, 1) để lấy xác suất
53
+ probs = torch.sigmoid(outputs['logits']).squeeze(0).cpu().numpy()
54
+
55
+ segment_scores = []
56
+ current_idx = 1 # Bỏ qua token đặc biệt <s> đầu chuỗi
57
+
58
+ for seg in segments:
59
+ # Lấy số lượng token của đoạn hiện tại
60
+ seg_len = len(tokenizer.encode(seg, add_special_tokens=False))
61
+ end_idx = min(current_idx + seg_len, len(probs))
62
+
63
+ if current_idx < len(probs):
64
+ # Điểm của câu là trung bình cộng xác suất các token bên trong nó
65
+ seg_score = np.mean(probs[current_idx:end_idx])
66
+ else:
67
+ seg_score = 0.0
68
+
69
+ segment_scores.append(seg_score)
70
+ current_idx += seg_len
71
+
72
+ ranked_indices = np.argsort(segment_scores)[::-1] # Sắp xếp giảm dần theo điểm
73
+
74
+ selected_indices = []
75
+ selected_trigrams = set()
76
+
77
+ for idx in ranked_indices:
78
+ candidate_seg = segments[idx]
79
+ candidate_trigrams = get_trigrams(candidate_seg)
80
+
81
+ # Chỉ chọn nếu segment này không trùng lặp các cụm từ (trigram) so với những đoạn đã chọn
82
+ if not candidate_trigrams.intersection(selected_trigrams):
83
+ selected_indices.append(idx)
84
+ selected_trigrams.update(candidate_trigrams)
85
+
86
+ # Dừng lại nếu đã gom đủ số câu theo yêu cầu
87
+ if len(selected_indices) == top_n:
88
+ break
89
+
90
+ # Sắp xếp lại thứ tự index xuất hiện của các câu trong văn bản gốc để tóm tắt được mạch lạc
91
+ selected_indices = sorted(selected_indices)
92
+
93
+ # Lắp ráp kết quả
94
+ extractive_summary = " ".join([segments[i] for i in selected_indices if i < len(segments)])
95
+ return extractive_summary
96
+
97
+ def model_extractive_abstract(prepro_dict: Dict) -> str:
98
+
99
+ return prepro_dict
100
+
101
+
102
+ # ====================== MAIN FUNCTION ======================
103
+ def ATS(
104
+ text: str,
105
+ segmentation_method: str = None,
106
+ model: str = None,
107
+ reference_summary: str = None
108
  ) -> str:
109
+ """Main workflow: Raw Text → Preprocessing → Model"""
110
+ # Step 1: Preprocessing
111
+ if segmentation_method == "Sentence-based Preprocessing":
112
+ prepro_dict = preprocess_external_text(text, reference_summary, segmentation_method='sentence')
 
 
 
113
  else:
114
+ prepro_dict = preprocess_external_text(text, reference_summary, segmentation_method='edu')
115
 
116
+ # Step 2: Chọn model
117
+ if model == "Baseline Model: TextRank + Vanilla BART":
118
+ result = model_baseline(prepro_dict)
119
+ elif model == "Baseline Model with Extractive":
120
+ result = model_baseline_extractive(prepro_dict)
121
+ else:
122
+ result = model_extractive_abstract(prepro_dict)
123
 
124
+ return result
125
 
126
 
127
+ # ====================== GRADIO INTERFACE ======================
128
  with gr.Blocks(
129
+ title="Automated Text Summarization System",
130
  theme=gr.themes.Soft(),
131
  css="""
132
  .gradio-container {max-width: 1200px; margin: auto;}
 
135
  ) as demo:
136
  gr.Markdown(
137
  """
138
+ # 🚀 Automated Text Summarization System
139
+ **Input textSelect method & modelGet results instantly**
140
  """
141
  )
142
 
143
  with gr.Row():
144
  with gr.Column(scale=3):
145
  input_text = gr.Textbox(
146
+ label="📝 Text to Summarize",
147
+ placeholder="Paste your long text here (up to several thousand words)...",
148
  lines=12,
149
  max_lines=30,
150
  show_copy_button=True
151
  )
152
 
153
  with gr.Column(scale=1):
154
+ gr.Markdown("### ⚙️ Settings")
155
 
156
+ method = gr.Radio(
157
  choices=[
158
+ "Sentence-based Preprocessing",
159
+ "Semantic-based Preprocessing (EDU)"
160
  ],
161
+ value="Sentence-based Preprocessing",
162
+ label="Preprocessing Method",
163
+ info="Choose how to clean the text before summarization"
164
  )
165
 
166
+ model = gr.Radio(
167
  choices=[
168
+ "Baseline Model",
169
+ "Baseline Model with Extractive",
170
+ "Extractive and Abstractive Model"
171
  ],
172
+ value="Baseline Model",
173
+ label="Summarization Model",
174
+ info="Select the model you want to use"
175
  )
176
 
177
  with gr.Row():
178
  btn_tom_tat = gr.Button(
179
+ "🔍 Summarize Now",
180
  variant="primary",
181
  size="large"
182
  )
183
 
184
  output_text = gr.Textbox(
185
+ label="📄 Summary Result",
186
  lines=10,
187
+ placeholder="The result will appear here...",
188
  show_copy_button=True
189
  )
190
 
191
+ # Connect button click
192
  btn_tom_tat.click(
193
+ fn=ATS,
194
+ inputs=[input_text, method, model],
195
  outputs=output_text
196
  )
197
 
198
+ # Examples
199
  gr.Examples(
200
  examples=[
201
+ ["Hanoi is the capital of Vietnam. The city has a history of over 1000 years. The population is about 8 million people. It is the political, economic, and cultural center of the country."],
202
+ ["Artificial Intelligence (AI) is changing the world. Many large companies are investing heavily in this field. Gradio helps developers quickly build web interfaces for AI models."],
203
  ],
204
  inputs=input_text,
205
+ label="📌 Test Examples"
206
  )
207
 
208
  gr.Markdown(
209
  """
210
  ---
211
+ 💡 **User Guide**:
212
+ 1. Paste the text to be summarized into the left box.
213
+ 2. Select the preprocessing method and model.
214
+ 3. Click **Summarize Now**.
215
+ The result will appear instantly (currently using mock data for demo).
216
  """
217
  )
218
 
219
 
220
+ # Launch the app
221
  if __name__ == "__main__":
222
  demo.launch()
src/model/baseline_extractive_model.py ADDED
@@ -0,0 +1,45 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import numpy as np
4
+ from transformers import BartModel, BartTokenizer
5
+
6
+ class BartExtractiveSummarizer(nn.Module):
7
+ def __init__(self, model_name="facebook/bart-large"):
8
+ super(BartExtractiveSummarizer, self).__init__()
9
+ self.encoder = BartModel.from_pretrained(model_name).encoder
10
+ hidden_size = self.encoder.config.hidden_size
11
+ self.classifier = nn.Linear(hidden_size, 1)
12
+
13
+ def forward(self, input_ids, attention_mask, saliency_mask=None, **kwargs):
14
+ encoder_outputs = self.encoder(
15
+ input_ids=input_ids,
16
+ attention_mask=attention_mask
17
+ )
18
+ hidden_states = encoder_outputs.last_hidden_state
19
+ logits = self.classifier(hidden_states).squeeze(-1)
20
+
21
+ loss = None
22
+ if saliency_mask is not None:
23
+ active_loss = attention_mask.view(-1) == 1
24
+ active_logits = logits.view(-1)[active_loss]
25
+ active_labels = saliency_mask.view(-1)[active_loss].float()
26
+
27
+ # --- TỐI ƯU: TỰ ĐỘNG TÍNH CLASS WEIGHT CHO TỪNG BATCH ---
28
+ num_pos = active_labels.sum()
29
+ num_neg = active_labels.size(0) - num_pos
30
+
31
+ if num_pos > 0:
32
+ weight = torch.clamp(num_neg / num_pos, max=10.0)
33
+ else:
34
+ weight = torch.tensor(1.0).to(logits.device)
35
+
36
+ loss_fct = nn.BCEWithLogitsLoss(pos_weight=weight)
37
+ loss = loss_fct(active_logits, active_labels)
38
+
39
+ return {"loss": loss, "logits": logits} if loss is not None else {"logits": logits}
40
+
41
+ def get_trigrams(text: str):
42
+ """Tạo tập hợp các cụm 3 từ liên tiếp từ một đoạn văn bản (Trigram Blocking)"""
43
+ words = text.lower().split()
44
+ return set(tuple(words[i:i+3]) for i in range(len(words)-2))
45
+
src/model/baseline_model.py ADDED
@@ -0,0 +1,87 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import networkx as nx
2
+ from sklearn.feature_extraction.text import TfidfVectorizer
3
+ from sklearn.metrics.pairwise import cosine_similarity
4
+ import torch
5
+ from transformers import BartForConditionalGeneration, BartTokenizer
6
+
7
+ def textrank_summarize(sentences, top_n=3):
8
+ """
9
+ Trích xuất câu quan trọng bằng TextRank + TF-IDF.
10
+ Đầu vào 'sentences' là một list các câu (hoặc EDUs) trong một văn bản.
11
+ """
12
+ # Xử lý trường hợp bài báo quá ngắn
13
+ if len(sentences) <= top_n:
14
+ return " ".join(sentences)
15
+
16
+ try:
17
+ # Bước 1: Khởi tạo TfidfVectorizer và fit_transform tập sentences của 1 bài báo
18
+ vectorizer = TfidfVectorizer(stop_words='english')
19
+ tfidf_matrix = vectorizer.fit_transform(sentences)
20
+
21
+ # Bước 2: Tính ma trận tương đồng Cosine
22
+ similarity_matrix = cosine_similarity(tfidf_matrix, tfidf_matrix)
23
+
24
+ # Bước 3: Đưa ma trận vào networkx tạo đồ thị và tính PageRank
25
+ nx_graph = nx.from_numpy_array(similarity_matrix)
26
+ scores = nx.pagerank(nx_graph)
27
+
28
+ # Bước 4: Sắp xếp điểm số và chọn top_n câu
29
+ ranked_sentences = sorted(((scores[i], s) for i, s in enumerate(sentences)), reverse=True)
30
+
31
+ # Giữ đúng thứ tự xuất hiện của câu trong văn bản gốc để dễ đọc
32
+ top_sentences_indices = sorted([sentences.index(ranked_sentences[i][1]) for i in range(top_n)])
33
+ summary = " ".join([sentences[i] for i in top_sentences_indices])
34
+ return summary
35
+
36
+ except Exception as e:
37
+ # Fallback về Lead-N nếu đồ thị lỗi (do câu rỗng hoặc không có từ vựng)
38
+ return " ".join(sentences[:top_n])
39
+
40
+
41
+ class BartSummarizer:
42
+ def __init__(self, model_path="facebook/bart-base"):
43
+ """
44
+ Khởi tạo mô hình và tokenizer.
45
+ model_path có thể là repo trên Hugging Face hoặc đường dẫn local chứa weights.
46
+ """
47
+ self.device = "cuda" if torch.cuda.is_available() else "cpu"
48
+ print(f"Loading BART model from '{model_path}' onto {self.device}...")
49
+
50
+ self.tokenizer = BartTokenizer.from_pretrained(model_path)
51
+ self.model = BartForConditionalGeneration.from_pretrained(model_path)
52
+ self.model.to(self.device)
53
+ self.model.eval() # Chuyển sang chế độ inference ngay từ đầu
54
+
55
+ def summarize(self, text, max_input_length=512, max_output_length=128, min_output_length=30):
56
+ """
57
+ Hàm sinh tóm tắt cho một đoạn văn bản đầu vào.
58
+ """
59
+ with torch.no_grad():
60
+ # Cắt ngắn đầu vào để chống quá tải GPU, đồng bộ với lúc train
61
+ inputs = self.tokenizer(
62
+ text,
63
+ max_length=max_input_length,
64
+ truncation=True,
65
+ padding=True,
66
+ return_tensors="pt"
67
+ ).to(self.device)
68
+
69
+ # Sinh văn bản tóm tắt
70
+ summary_ids = self.model.generate(
71
+ input_ids=inputs["input_ids"],
72
+ attention_mask=inputs["attention_mask"],
73
+ max_length=max_output_length,
74
+ min_length=min_output_length,
75
+ num_beams=4,
76
+ length_penalty=2.0, # Ưu tiên sinh câu trọn vẹn
77
+ no_repeat_ngram_size=3, # Chống ảo giác, lặp từ
78
+ early_stopping=True
79
+ )
80
+
81
+ # Decode kết quả về dạng text
82
+ summary = self.tokenizer.decode(summary_ids[0], skip_special_tokens=True)
83
+ return summary
84
+
85
+ # Cách gọi trong app:
86
+ # summarizer = BartSummarizer("duong_dan_model_cua_ban_tren_huggingface")
87
+ # result = summarizer.summarize("Đoạn văn bản cần tóm tắt...")
src/preprocessing/edu_sentences.py ADDED
@@ -0,0 +1,178 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # ========================== preprocessing_utils.py ==========================
2
+ import re
3
+ import nltk
4
+ import numpy as np
5
+ import spacy
6
+ from transformers import BartTokenizer
7
+ from rouge_score import rouge_scorer
8
+ from typing import List, Dict, Optional, Union
9
+
10
+ nltk.download('punkt', quiet=True)
11
+ nltk.download('punkt_tab', quiet=True)
12
+
13
+ # Tải SpaCy một lần duy nhất (nhẹ, disable các thành phần không cần)
14
+ nlp = spacy.load("en_core_web_sm", disable=["ner", "lemmatizer", "attribute_ruler", "tok2vec"])
15
+
16
+ tokenizer = BartTokenizer.from_pretrained("facebook/bart-large-cnn")
17
+ scorer = rouge_scorer.RougeScorer(['rouge1', 'rouge2', 'rougeL'], use_stemmer=True)
18
+
19
+ def clean_text(text: str) -> str:
20
+ """Làm sạch văn bản (dùng chung cho mọi pipeline)"""
21
+ if not isinstance(text, str):
22
+ return ""
23
+ # Xóa URL, email, twitter handle
24
+ text = re.sub(r'http\S+|www\S+|https\S+', '', text, flags=re.MULTILINE)
25
+ text = re.sub(r'\S+@\S+', '', text)
26
+ text = re.sub(r'@[A-Za-z0-9_]+', '', text)
27
+ # Giữ lại chữ, số, dấu câu cơ bản
28
+ text = re.sub(r'[^\w\s.,;:\'"-?!]', '', text)
29
+ # Chuẩn hóa khoảng trắng
30
+ text = re.sub(r'\s+', ' ', text).strip()
31
+ return text
32
+
33
+
34
+ def segment_text(text: str, method: str = 'sentence') -> tuple[List[str], str]:
35
+ """
36
+ Phân tách văn bản theo phương pháp được chọn.
37
+ Trả về: (list_segments, cleaned_text)
38
+ """
39
+ cleaned = clean_text(text)
40
+
41
+ if method == 'sentence':
42
+ segments = nltk.sent_tokenize(cleaned)
43
+ return segments, cleaned
44
+
45
+ elif method == 'edu':
46
+ # Giống hệt logic notebook EDU (tách câu trước → EDU bằng SpaCy)
47
+ sentences = nltk.sent_tokenize(cleaned)
48
+ processed_docs = list(nlp.pipe(sentences, batch_size=500))
49
+
50
+ all_edus = []
51
+ for doc in processed_docs:
52
+ temp_edus, current_segment = [], []
53
+ for token in doc:
54
+ current_segment.append(token.text_with_ws)
55
+ if (token.pos_ in ["SCONJ", "CCONJ"] or token.text in [",", ";"]) and len(current_segment) > 3:
56
+ temp_edus.append("".join(current_segment).strip())
57
+ current_segment = []
58
+ if current_segment:
59
+ temp_edus.append("".join(current_segment).strip())
60
+ all_edus.extend(temp_edus if temp_edus else [doc.text])
61
+
62
+ return all_edus, cleaned
63
+
64
+ else:
65
+ raise ValueError("method phải là 'sentence' hoặc 'edu'")
66
+
67
+
68
+ def greedy_rouge_selection(segments: List[str], reference_summary: str, top_k: int = 3) -> List[int]:
69
+ """Thuật toán Greedy ROUGE (dùng chung)"""
70
+ selected_indices = []
71
+ best_rouge = 0.0
72
+ if not segments:
73
+ return []
74
+
75
+ for _ in range(min(top_k, len(segments))):
76
+ best_idx = -1
77
+ current_best = best_rouge
78
+
79
+ for i, seg in enumerate(segments):
80
+ if i in selected_indices:
81
+ continue
82
+ candidate = " ".join([segments[j] for j in selected_indices] + [seg])
83
+ scores = scorer.score(reference_summary, candidate)
84
+ avg_f = (scores['rouge1'].fmeasure +
85
+ scores['rouge2'].fmeasure +
86
+ scores['rougeL'].fmeasure) / 3.0
87
+
88
+ if avg_f > current_best:
89
+ current_best = avg_f
90
+ best_idx = i
91
+
92
+ if best_idx != -1:
93
+ selected_indices.append(best_idx)
94
+ best_rouge = current_best
95
+ else:
96
+ break
97
+
98
+ return [1 if i in selected_indices else 0 for i in range(len(segments))]
99
+
100
+
101
+ def create_saliency_mask(input_ids: List[int], segments: List[str],
102
+ ext_labels: List[int], tokenizer) -> List[int]:
103
+ """Tạo Saliency Mask từ segment-level xuống token-level"""
104
+ mask = np.zeros(len(input_ids), dtype=int)
105
+ mask[0] = 1
106
+ if input_ids and input_ids[-1] == tokenizer.eos_token_id:
107
+ mask[-1] = 1
108
+
109
+ current_idx = 1
110
+ for seg_idx, segment in enumerate(segments):
111
+ if current_idx >= len(input_ids) - 1:
112
+ break
113
+ seg_tokens = tokenizer.encode(segment, add_special_tokens=False)
114
+ token_len = len(seg_tokens)
115
+
116
+ if seg_idx < len(ext_labels) and ext_labels[seg_idx] == 1:
117
+ end_idx = min(current_idx + token_len, len(input_ids) - 1)
118
+ mask[current_idx:end_idx] = 1
119
+
120
+ current_idx += token_len
121
+
122
+ return mask.tolist()
123
+
124
+
125
+ def preprocess_external_text(
126
+ text: str,
127
+ reference_summary: Optional[str] = None,
128
+ segmentation_method: str = 'sentence',
129
+ top_k: int = 3,
130
+ max_length: int = 1024
131
+ ) -> Dict:
132
+
133
+ segments, cleaned_article = segment_text(text, method=segmentation_method)
134
+
135
+ inputs = tokenizer(cleaned_article, max_length=max_length, truncation=True, padding=False)
136
+
137
+ result = {
138
+ "article": cleaned_article,
139
+ "segments": segments, # ← list câu hoặc list EDU
140
+ "segmentation_method": segmentation_method,
141
+ "input_ids": inputs["input_ids"],
142
+ "attention_mask": inputs["attention_mask"],
143
+ }
144
+
145
+ # Nếu có tóm tắt tham chiếu → tính nhãn extractive
146
+ if reference_summary is not None:
147
+ ref_clean = clean_text(reference_summary)
148
+ extractive_labels = greedy_rouge_selection(segments, ref_clean, top_k=top_k)
149
+ saliency_mask = create_saliency_mask(inputs["input_ids"], segments, extractive_labels, tokenizer)
150
+
151
+ targets = tokenizer(ref_clean, max_length=128, truncation=True, padding=False)
152
+
153
+ result.update({
154
+ "extractive_labels": extractive_labels,
155
+ "saliency_mask": saliency_mask,
156
+ "labels": targets["input_ids"], # cho phần Abstractive
157
+ "reference_summary": ref_clean
158
+ })
159
+
160
+ return result
161
+
162
+
163
+ def preprocess_batch(
164
+ texts: List[str],
165
+ reference_summaries: Optional[List[str]] = None,
166
+ segmentation_method: str = 'sentence',
167
+ top_k: int = 3
168
+ ) -> List[Dict]:
169
+ """Xử lý nhiều văn bản cùng lúc (dùng cho demo batch)"""
170
+ if reference_summaries is None:
171
+ reference_summaries = [None] * len(texts)
172
+ if len(reference_summaries) != len(texts):
173
+ raise ValueError("Số lượng reference_summaries phải bằng số lượng texts")
174
+
175
+ return [
176
+ preprocess_external_text(txt, ref, segmentation_method, top_k)
177
+ for txt, ref in zip(texts, reference_summaries)
178
+ ]
src/utils/get_model.py ADDED
@@ -0,0 +1,35 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from src.model.baseline_model import BartSummarizer
2
+ from src.model.baseline_extractive_model import BartExtractiveSummarizer
3
+ loaded_summarizers = {}
4
+ import torch
5
+ from huggingface_hub import hf_hub_download
6
+
7
+
8
+
9
+
10
+ def get_summarizer(repo_id: str):
11
+
12
+ if repo_id not in loaded_summarizers:
13
+ loaded_summarizers[repo_id] = BartSummarizer(model_path=repo_id)
14
+ return loaded_summarizers[repo_id]
15
+
16
+ def get_extractive_model(repo_id: str, base_model_name: str = "facebook/bart-large", device: torch.device = "cpu"):
17
+ """Tải và lưu cache mô hình Custom Extractive từ Hugging Face Hub"""
18
+ if repo_id not in loaded_summarizers:
19
+ print(f"Đang tải mô hình Extractive từ repo: {repo_id}...")
20
+
21
+ # Khởi tạo khung kiến trúc trống
22
+ model = BartExtractiveSummarizer(model_name=base_model_name)
23
+
24
+ # Sử dụng hf_hub_download để kéo file trọng số về local cache
25
+ model_path = hf_hub_download(repo_id=repo_id, filename="model_state.bin")
26
+
27
+ # Load trọng số vào model
28
+ model.load_state_dict(torch.load(model_path, map_location=device))
29
+
30
+ model.to(device)
31
+ model.eval() # Chuyển mô hình sang chế độ inference
32
+
33
+ loaded_summarizers[repo_id] = model
34
+
35
+ return loaded_summarizers[repo_id]