# ========================== preprocessing_utils.py ==========================
import re
import nltk
import numpy as np
import spacy
from transformers import BartTokenizer
from rouge_score import rouge_scorer
from typing import List, Dict, Optional, Tuple
nltk.download('punkt', quiet=True)
nltk.download('punkt_tab', quiet=True)
# Load spaCy once. NER and the lemmatizer are unused here, but keep tok2vec
# (the tagger and parser listen to it) and the attribute_ruler (it maps
# fine-grained tags to token.pos_, which the EDU heuristic below relies on).
nlp = spacy.load("en_core_web_sm", disable=["ner", "lemmatizer"])
tokenizer = BartTokenizer.from_pretrained("facebook/bart-large-cnn")
scorer = rouge_scorer.RougeScorer(['rouge1', 'rouge2', 'rougeL'], use_stemmer=True)
def clean_text(text: str) -> str:
    """Clean raw text (shared by every pipeline)."""
    if not isinstance(text, str):
        return ""
    # Remove URLs, e-mail addresses, and Twitter handles
    text = re.sub(r'http\S+|www\S+', '', text, flags=re.MULTILINE)
    text = re.sub(r'\S+@\S+', '', text)
    text = re.sub(r'@[A-Za-z0-9_]+', '', text)
    # Keep letters, digits, whitespace, and basic punctuation. The hyphen is
    # placed last in the character class so it is a literal '-' rather than
    # an accidental range.
    text = re.sub(r'[^\w\s.,;:\'"?!-]', '', text)
    # Collapse runs of whitespace
    text = re.sub(r'\s+', ' ', text).strip()
    return text
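# Illustrative behavior (not a doctest; the sample string is made up, and the
# output follows directly from the regexes above):
#   clean_text("Visit https://example.com or mail a@b.com, thanks @bot!")
#   -> "Visit or mail thanks !"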
def segment_text(text: str, method: str = 'sentence') -> Tuple[List[str], str]:
    """
    Split text with the chosen segmentation method.
    Returns: (list_of_segments, cleaned_text)
    """
    cleaned = clean_text(text)
    if method == 'sentence':
        segments = nltk.sent_tokenize(cleaned)
        return segments, cleaned
    elif method == 'edu':
        # Same logic as the EDU notebook: sentence-split first, then split
        # each sentence into EDU-like clauses with spaCy.
        sentences = nltk.sent_tokenize(cleaned)
        processed_docs = list(nlp.pipe(sentences, batch_size=500))
        all_edus = []
        for doc in processed_docs:
            temp_edus, current_segment = [], []
            for token in doc:
                current_segment.append(token.text_with_ws)
                # Cut at conjunctions or clause punctuation, but only once the
                # running segment exceeds three tokens (avoids tiny fragments)
                if (token.pos_ in ["SCONJ", "CCONJ"] or token.text in [",", ";"]) and len(current_segment) > 3:
                    temp_edus.append("".join(current_segment).strip())
                    current_segment = []
            if current_segment:
                temp_edus.append("".join(current_segment).strip())
            # Fall back to the whole sentence if no cut point was found
            all_edus.extend(temp_edus if temp_edus else [doc.text])
        return all_edus, cleaned
    else:
        raise ValueError("method must be 'sentence' or 'edu'")
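# Illustrative call (EDU boundaries depend on the en_core_web_sm POS tags, so
# the exact split may vary across model versions):
#   segs, cleaned = segment_text("Although it rained, we went out. It was fun.",
#                                method='edu')
#   # segs ~ ["Although it rained,", "we went out.", "It was fun."]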
def greedy_rouge_selection(segments: List[str], reference_summary: str, top_k: int = 3) -> List[int]:
    """
    Greedy ROUGE selection (shared): repeatedly add the segment that most
    improves the mean of ROUGE-1/2/L F1 against the reference.
    Returns a 0/1 label per segment (1 = selected), not a list of indices.
    """
selected_indices = []
best_rouge = 0.0
if not segments:
return []
for _ in range(min(top_k, len(segments))):
best_idx = -1
current_best = best_rouge
for i, seg in enumerate(segments):
if i in selected_indices:
continue
candidate = " ".join([segments[j] for j in selected_indices] + [seg])
scores = scorer.score(reference_summary, candidate)
avg_f = (scores['rouge1'].fmeasure +
scores['rouge2'].fmeasure +
scores['rougeL'].fmeasure) / 3.0
if avg_f > current_best:
current_best = avg_f
best_idx = i
if best_idx != -1:
selected_indices.append(best_idx)
best_rouge = current_best
else:
break
return [1 if i in selected_indices else 0 for i in range(len(segments))]
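# Illustrative usage (scores come from rouge_scorer with stemming, so the
# winner can shift on other inputs; here the first segment should win alone):
#   greedy_rouge_selection(["The cat sat.", "Dogs bark.", "The cat is here."],
#                          "The cat sat here.", top_k=1)
#   -> [1, 0, 0]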
def create_saliency_mask(input_ids: List[int], segments: List[str],
                         ext_labels: List[int], tokenizer) -> List[int]:
    """Project segment-level extractive labels down to a token-level saliency mask."""
    if not input_ids:
        return []
    mask = np.zeros(len(input_ids), dtype=int)
    mask[0] = 1  # always keep BOS
    if input_ids[-1] == tokenizer.eos_token_id:
        mask[-1] = 1  # and EOS, when the sequence was not truncated mid-text
    current_idx = 1
    for seg_idx, segment in enumerate(segments):
        if current_idx >= len(input_ids) - 1:
            break  # remaining segments were truncated away
        # Approximate alignment: re-encoding a segment in isolation can drift
        # slightly from the tokenization of the full article (BPE is sensitive
        # to leading spaces), but the lengths are close enough to place spans.
        seg_tokens = tokenizer.encode(segment, add_special_tokens=False)
        token_len = len(seg_tokens)
        if seg_idx < len(ext_labels) and ext_labels[seg_idx] == 1:
            end_idx = min(current_idx + token_len, len(input_ids) - 1)
            mask[current_idx:end_idx] = 1
        current_idx += token_len
    return mask.tolist()
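# Sketch of the intended mask layout (token counts are illustrative; the real
# spans come from the BPE lengths computed above):
#   input_ids:  [<s>] [seg0: 5 tokens] [seg1: 4 tokens] [</s>]
#   ext_labels: [1, 0]
#   mask:       [1]   [1 1 1 1 1]      [0 0 0 0]        [1]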
def preprocess_external_text(
text: str,
reference_summary: Optional[str] = None,
segmentation_method: str = 'sentence',
top_k: int = 3,
max_length: int = 1024
) -> Dict:
    """Preprocess one external document: clean, segment, tokenize, and
    optionally attach extractive labels against a reference summary."""
segments, cleaned_article = segment_text(text, method=segmentation_method)
inputs = tokenizer(cleaned_article, max_length=max_length, truncation=True, padding=False)
result = {
"article": cleaned_article,
"segments": segments, # ← list câu hoặc list EDU
"segmentation_method": segmentation_method,
"input_ids": inputs["input_ids"],
"attention_mask": inputs["attention_mask"],
}
    # If a reference summary is provided, also compute the extractive labels
if reference_summary is not None:
ref_clean = clean_text(reference_summary)
extractive_labels = greedy_rouge_selection(segments, ref_clean, top_k=top_k)
saliency_mask = create_saliency_mask(inputs["input_ids"], segments, extractive_labels, tokenizer)
targets = tokenizer(ref_clean, max_length=128, truncation=True, padding=False)
result.update({
"extractive_labels": extractive_labels,
"saliency_mask": saliency_mask,
"labels": targets["input_ids"], # cho phần Abstractive
"reference_summary": ref_clean
})
return result
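# Typical calls (illustrative):
#   preprocess_external_text("Some article ...")                     # inference only
#   preprocess_external_text("Some article ...", "Its summary ...")  # also returns
#   # extractive_labels, saliency_mask, and decoder labels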
def preprocess_batch(
texts: List[str],
reference_summaries: Optional[List[str]] = None,
segmentation_method: str = 'sentence',
top_k: int = 3
) -> List[Dict]:
"""Xử lý nhiều văn bản cùng lúc (dùng cho demo batch)"""
if reference_summaries is None:
reference_summaries = [None] * len(texts)
if len(reference_summaries) != len(texts):
raise ValueError("Số lượng reference_summaries phải bằng số lượng texts")
return [
preprocess_external_text(txt, ref, segmentation_method, top_k)
for txt, ref in zip(texts, reference_summaries)
]