DemoApp / src /model /baseline_model.py
Reality8081's picture
Update SRC basline and baseline_extractive
fdcfd15
import networkx as nx
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.metrics.pairwise import cosine_similarity
import torch
from transformers import BartForConditionalGeneration, BartTokenizer
def textrank_summarize(sentences, top_n=3):
"""
Trích xuất câu quan trọng bằng TextRank + TF-IDF.
Đầu vào 'sentences' là một list các câu (hoặc EDUs) trong một văn bản.
"""
# Xử lý trường hợp bài báo quá ngắn
if len(sentences) <= top_n:
return " ".join(sentences)
try:
# Bước 1: Khởi tạo TfidfVectorizer và fit_transform tập sentences của 1 bài báo
vectorizer = TfidfVectorizer(stop_words='english')
tfidf_matrix = vectorizer.fit_transform(sentences)
# Bước 2: Tính ma trận tương đồng Cosine
similarity_matrix = cosine_similarity(tfidf_matrix, tfidf_matrix)
# Bước 3: Đưa ma trận vào networkx tạo đồ thị và tính PageRank
nx_graph = nx.from_numpy_array(similarity_matrix)
scores = nx.pagerank(nx_graph)
# Bước 4: Sắp xếp điểm số và chọn top_n câu
ranked_sentences = sorted(((scores[i], s) for i, s in enumerate(sentences)), reverse=True)
# Giữ đúng thứ tự xuất hiện của câu trong văn bản gốc để dễ đọc
top_sentences_indices = sorted([sentences.index(ranked_sentences[i][1]) for i in range(top_n)])
summary = " ".join([sentences[i] for i in top_sentences_indices])
return summary
except Exception as e:
# Fallback về Lead-N nếu đồ thị lỗi (do câu rỗng hoặc không có từ vựng)
return " ".join(sentences[:top_n])
class BartSummarizer:
def __init__(self, model_path="facebook/bart-base"):
"""
Khởi tạo mô hình và tokenizer.
model_path có thể là repo trên Hugging Face hoặc đường dẫn local chứa weights.
"""
self.device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Loading BART model from '{model_path}' onto {self.device}...")
self.tokenizer = BartTokenizer.from_pretrained(model_path)
self.model = BartForConditionalGeneration.from_pretrained(model_path)
self.model.to(self.device)
self.model.eval() # Chuyển sang chế độ inference ngay từ đầu
def summarize(self, text, max_input_length=512, max_output_length=128, min_output_length=30):
"""
Hàm sinh tóm tắt cho một đoạn văn bản đầu vào.
"""
with torch.no_grad():
# Cắt ngắn đầu vào để chống quá tải GPU, đồng bộ với lúc train
inputs = self.tokenizer(
text,
max_length=max_input_length,
truncation=True,
padding=True,
return_tensors="pt"
).to(self.device)
# Sinh văn bản tóm tắt
summary_ids = self.model.generate(
input_ids=inputs["input_ids"],
attention_mask=inputs["attention_mask"],
max_length=max_output_length,
min_length=min_output_length,
num_beams=4,
length_penalty=2.0, # Ưu tiên sinh câu trọn vẹn
no_repeat_ngram_size=3, # Chống ảo giác, lặp từ
early_stopping=True
)
# Decode kết quả về dạng text
summary = self.tokenizer.decode(summary_ids[0], skip_special_tokens=True)
return summary
# Cách gọi trong app:
# summarizer = BartSummarizer("duong_dan_model_cua_ban_tren_huggingface")
# result = summarizer.summarize("Đoạn văn bản cần tóm tắt...")