"""Gradio demo app for automated text summarization with BART-based models."""

import os
import re
from typing import Dict, List, Tuple

import gradio as gr
import networkx as nx
import numpy as np
import plotly.graph_objects as go
import torch
from transformers import BartForConditionalGeneration, BartTokenizer

from src.utils.get_model import get_summarizer, get_extractive_model, get_extractive_abstractive
from src.preprocessing.edu_sentences import preprocess_external_text
from src.model.baseline_extractive_model import get_trigrams

# Hugging Face Hub repositories, one pair per model (sentence-based / EDU-based).
REPO_ID_baseline_model = "Reality8081/bart-base"
REPO_ID_baseline_model_edu = "Reality8081/bart-base-edu"
REPO_ID_baseline_extractive_model = "Reality8081/bart_extractive"
REPO_ID_baseline_extractive_model_edu = "Reality8081/bart_extractive-edu"
REPO_ID_Extabs_model = "Reality8081/bart-encoder-decoder"
REPO_ID_Extabs_model_edu = "Reality8081/bart-encoder-decoder-edu"

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


def create_saliency_plot(segment_scores: List[float],
                         selected_indices: List[int],
                         segments: List[str]) -> go.Figure:
    """Build a bar chart of the saliency score of every segment.

    Segments that were selected into the summary are drawn red, the rest blue.
    The segment text is shortened to 60 characters for the hover tooltip.
    """
    colors = ['#EF4444' if i in selected_indices else '#3B82F6'
              for i in range(len(segment_scores))]
    # Plotly hover text is rendered as HTML, so line breaks must be <br>;
    # a literal "\n" would not produce a new line in the tooltip.
    hover_texts = [
        f"Segment {i}<br>Score: {score:.3f}<br>Text: {seg[:60]}..."
        for i, (score, seg) in enumerate(zip(segment_scores, segments))
    ]
    fig = go.Figure(data=[go.Bar(
        x=[f"Seg {i}" for i in range(len(segment_scores))],
        y=segment_scores,
        marker_color=colors,
        text=[f"{s:.2f}" for s in segment_scores],
        textposition='auto',
        hovertext=hover_texts,
        hoverinfo="text"
    )])
    fig.update_layout(
        title="Saliency Scores trên từng Câu/EDU",
        xaxis_title="Vị trí Câu / EDU",
        yaxis_title="Saliency Score (0-1)",
        template="plotly_white",
        margin=dict(l=40, r=40, t=40, b=40),
        height=350
    )
    return fig


def _score_segments(probs: np.ndarray, segments: List[str], tokenizer) -> List[float]:
    """Average per-token selection probabilities over each segment.

    ``probs`` is a 1-D array of token-level probabilities aligned with the
    tokenized article; position 0 is the special token at the start of the
    sequence and is skipped. A segment whose tokens fall entirely past the
    (possibly truncated) input gets a score of 0.0.
    """
    segment_scores = []
    current_idx = 1  # skip the special token at the start of the sequence
    for seg in segments:
        # Number of tokens occupied by the current segment.
        seg_len = len(tokenizer.encode(seg, add_special_tokens=False))
        end_idx = min(current_idx + seg_len, len(probs))
        if current_idx < len(probs):
            # A segment's score is the mean probability of its tokens.
            seg_score = float(np.mean(probs[current_idx:end_idx]))
        else:
            seg_score = 0.0
        segment_scores.append(seg_score)
        current_idx += seg_len
    return segment_scores


def model_baseline(prepro_dict: Dict) -> Tuple[str, float, go.Figure]:
    """Plain abstractive baseline: summarize the raw article text directly.

    Returns (summary, avg_confidence, figure); the last two are always None
    because this model produces no per-segment saliency scores.
    """
    text_to_summarize = prepro_dict.get("article", "")
    if not text_to_summarize.strip():
        # Keep the 3-tuple contract even on error so callers can always
        # unpack three values (the original returned a bare string here,
        # which crashed the caller's tuple unpacking).
        return "Lỗi: Không tìm thấy văn bản để tóm tắt.", None, None

    segmentation_method = prepro_dict.get("segmentation_method")
    if segmentation_method == "edu":
        repo_id = REPO_ID_baseline_model_edu
    else:
        repo_id = REPO_ID_baseline_model

    summarizer = get_summarizer(repo_id)
    summary = summarizer.summarize(text_to_summarize)
    return summary, None, None


def model_baseline_extractive(prepro_dict: Dict, top_n: int = 5) -> Tuple[str, float, go.Figure]:
    """Extractive baseline: rank segments by saliency and keep the top ones.

    Trigram blocking prevents the selected segments from repeating the same
    phrases. Returns (extractive summary, mean confidence of the selected
    segments, saliency bar chart).
    """
    tokenizer = BartTokenizer.from_pretrained("facebook/bart-large")
    segments = prepro_dict["segments"]
    if not segments:
        raise ValueError("Không thể phân tách văn bản thành các câu/EDU.")

    input_ids = torch.tensor([prepro_dict["input_ids"]]).to(device)
    attention_mask = torch.tensor([prepro_dict["attention_mask"]]).to(device)

    segmentation_method = prepro_dict.get("segmentation_method")
    if segmentation_method == "edu":
        repo_id = REPO_ID_baseline_extractive_model_edu
    else:
        repo_id = REPO_ID_baseline_extractive_model

    model = get_extractive_model(repo_id=repo_id, device=device)
    model = model.to(torch.float32)
    model.eval()

    with torch.no_grad(), torch.autocast(device_type="cuda" if torch.cuda.is_available() else "cpu"):
        outputs = model(input_ids=input_ids, attention_mask=attention_mask)

    # Sigmoid maps the per-token logits into (0, 1) selection probabilities.
    logits = outputs['logits'].to(torch.float32)
    probs = torch.sigmoid(logits).squeeze(0).cpu().numpy()

    segment_scores = _score_segments(probs, segments, tokenizer)

    # Greedily pick the highest-scoring segments, skipping any candidate
    # whose trigrams overlap an already-selected segment (trigram blocking).
    ranked_indices = np.argsort(segment_scores)[::-1]
    selected_indices = []
    selected_trigrams = set()
    for idx in ranked_indices:
        candidate_seg = segments[idx]
        candidate_trigrams = get_trigrams(candidate_seg)
        if not candidate_trigrams.intersection(selected_trigrams):
            selected_indices.append(int(idx))
            selected_trigrams.update(candidate_trigrams)
            # Stop once the requested number of segments has been collected.
            if len(selected_indices) == top_n:
                break

    # Restore document order so the extractive summary reads coherently.
    selected_indices = sorted(selected_indices)

    extractive_summary = " ".join(segments[i] for i in selected_indices if i < len(segments))
    avg_confidence = (float(np.mean([segment_scores[i] for i in selected_indices]))
                      if selected_indices else 0.0)
    fig = create_saliency_plot(segment_scores, selected_indices, segments)
    return extractive_summary, avg_confidence, fig


def model_extractive_abstract(prepro_dict: Dict) -> Tuple[str, float, go.Figure]:
    """Hybrid model: extractive head for saliency + abstractive generation.

    The extractive head only drives the saliency visualization (segments with
    probability >= 0.5 are highlighted); the summary itself is generated
    abstractively by the BART decoder via ``generate_summary``.
    """
    tokenizer = BartTokenizer.from_pretrained("facebook/bart-large")
    segments = prepro_dict["segments"]
    input_ids = torch.tensor([prepro_dict["input_ids"]]).to(device)
    attention_mask = torch.tensor([prepro_dict["attention_mask"]]).to(device)

    segmentation_method = prepro_dict.get("segmentation_method")
    if segmentation_method == "edu":
        repo_id = REPO_ID_Extabs_model_edu
    else:
        repo_id = REPO_ID_Extabs_model

    model = get_extractive_abstractive(repo_id=repo_id, base_model_name="facebook/bart-large", device=device)

    with torch.no_grad():
        encoder_outputs = model.bart.model.encoder(input_ids=input_ids, attention_mask=attention_mask)
        hidden_states = encoder_outputs.last_hidden_state
        ext_logits = model.ext_head(hidden_states).squeeze(-1)
        probs = torch.sigmoid(ext_logits).squeeze(0).cpu().numpy()

    segment_scores = _score_segments(probs, segments, tokenizer)

    # Segments the extractive head is confident about; used only for the plot
    # and the confidence metric, not concatenated into the summary.
    selected_indices = [i for i, score in enumerate(segment_scores) if score >= 0.5]
    avg_confidence = (float(np.mean([segment_scores[i] for i in selected_indices]))
                      if selected_indices else 0.0)

    summary_ids = model.generate_summary(
        input_ids=input_ids,
        attention_mask=attention_mask,
        max_length=150,
        min_length=40,
        num_beams=4,
        early_stopping=True
    )
    summary = tokenizer.decode(summary_ids[0], skip_special_tokens=True)

    fig = create_saliency_plot(segment_scores, selected_indices, segments)
    return summary, avg_confidence, fig


# ====================== MAIN FUNCTION ======================
def ATS(
    text: str,
    segmentation_method: str = None,
    model: str = None,
    reference_summary: str = None
) -> Tuple[str, str, go.Figure]:
    """Gradio entry point: preprocess, dispatch to a model, format metrics.

    Returns (summary text, markdown metrics block, saliency figure); on error
    the figure slot is None and the metrics slot carries a warning string.
    """
    if not text or len(text.split()) < 5:
        return "Văn bản quá ngắn hoặc rỗng, vui lòng nhập thêm nội dung.", "⚠️ Không có dữ liệu", None

    try:
        # Step 1: preprocessing (sentence split or EDU segmentation).
        if segmentation_method == "Sentence-based Preprocessing":
            prepro_dict = preprocess_external_text(text, reference_summary, segmentation_method='sentence')
        else:
            prepro_dict = preprocess_external_text(text, reference_summary, segmentation_method='edu')

        # Step 2: model dispatch.
        if model == "Baseline Model":
            result, avg_conf, plot_fig = model_baseline(prepro_dict)
        elif model == "Baseline Model with Extractive":
            result, avg_conf, plot_fig = model_baseline_extractive(prepro_dict)
        else:
            result, avg_conf, plot_fig = model_extractive_abstract(prepro_dict)

        origin_words = len(text.split())
        sum_words = len(result.split())
        comp_ratio = (sum_words / origin_words * 100) if origin_words > 0 else 0
        conf_str = f"**{avg_conf * 100:.2f}%**" if avg_conf is not None else "*N/A (Không tính Saliency Score)*"

        metrics_md = f"""
### 📊 Thống kê kết quả
- **Tổng số từ văn bản gốc:** {origin_words} từ
- **Số từ bản tóm tắt:** {sum_words} từ
- **Tỉ lệ nén:** {comp_ratio:.1f}% *(Giữ lại {comp_ratio:.1f}% dung lượng gốc)*
- **Độ tin cậy trung bình (Confidence):** {conf_str}
"""

        if plot_fig is None:
            # The plain baseline model produces no saliency scores.
            plot_fig = go.Figure()
            plot_fig.update_layout(title="Mô hình Baseline không hỗ trợ Saliency Plot.", template="plotly_white")

        return result, metrics_md, plot_fig

    except Exception as e:
        # Top-level UI boundary: surface the error instead of crashing the app.
        return f"Đã xảy ra lỗi: {str(e)}", "⚠️ Lỗi hệ thống", None


# ====================== GRADIO INTERFACE ======================
with gr.Blocks(
    title="Automated Text Summarization System",
    # theme/css belong on the Blocks constructor, not on launch():
    # Blocks.launch() has no such parameters and would raise TypeError.
    theme=gr.themes.Soft(),
    css="""
    .gradio-container {max-width: 1200px; margin: auto;}
    .title {text-align: center; margin-bottom: 10px;}
    """,
) as demo:
    gr.Markdown(
        """
        # 🚀 Automated Text Summarization System
        **Input text → Select method & model → Get results instantly**
        """
    )

    with gr.Row():
        with gr.Column(scale=2):
            input_text = gr.Textbox(
                label="📝 Text to Summarize",
                placeholder="Paste your long text here (up to several thousand words)...",
                lines=15,
                max_lines=30,
            )
            with gr.Row():
                btn_summary = gr.Button(
                    "🔍 Summarize Now",
                    variant="primary",
                    size="large"
                )
        with gr.Column(scale=1):
            gr.Markdown("### ⚙️ Settings")
            method = gr.Radio(
                choices=[
                    "Sentence-based Preprocessing",
                    "Semantic-based Preprocessing (EDU)"
                ],
                value="Sentence-based Preprocessing",
                label="Preprocessing Method",
                info="Choose how to clean the text before summarization"
            )
            model = gr.Radio(
                choices=[
                    "Baseline Model",
                    "Baseline Model with Extractive",
                    "Extractive and Abstractive Model"
                ],
                value="Baseline Model",
                label="Summarization Model",
                info="Select the model you want to use"
            )
            output_metrics = gr.Markdown("### 📊 Thống kê kết quả\n*Chờ xử lý...*")

    gr.Markdown("---")

    with gr.Row():
        with gr.Column(scale=1):
            output_text = gr.Textbox(
                label="📄 Summary Result",
                lines=10,
                placeholder="The result will appear here...",
            )
        with gr.Column(scale=1):
            output_plot = gr.Plot(label="📈 Saliency Score Plot")

    # Connect button click
    btn_summary.click(
        fn=ATS,
        inputs=[input_text, method, model],
        outputs=[output_text, output_metrics, output_plot]
    )

    # Examples
    gr.Examples(
        examples=[
            ["Hanoi is the capital of Vietnam. The city has a history of over 1000 years. The population is about 8 million people. It is the political, economic, and cultural center of the country."],
            ["Artificial Intelligence (AI) is changing the world. Many large companies are investing heavily in this field. Gradio helps developers quickly build web interfaces for AI models."],
        ],
        inputs=input_text,
        label="📌 Test Examples"
    )

    gr.Markdown(
        """
        ---
        💡 **User Guide**:
        1. Paste the text to be summarized into the left box.
        2. Select the preprocessing method and model.
        3. Click **Summarize Now**. The result will appear instantly (currently using mock data for demo).
        """
    )

# Launch the app
if __name__ == "__main__":
    demo.launch(
        share=True,
        ssr_mode=False
    )