Reality8081 committed on
Commit
1dde759
·
1 Parent(s): d71e8ac

Update src

Browse files
app.py CHANGED
@@ -1,4 +1,4 @@
1
- from aiofiles import os
2
  import gradio as gr
3
  import torch
4
  from transformers import BartForConditionalGeneration, BartTokenizer
@@ -6,17 +6,17 @@ import re
6
  import numpy as np
7
  import networkx as nx
8
  from typing import List, Dict
9
- from src.utils.get_model import get_summarizer
10
  from src.preprocessing.edu_sentences import preprocess_external_text
11
- from src.utils.get_model import get_extractive_model
12
  from src.model.baseline_extractive_model import get_trigrams
13
-
14
- hf_token = os.environ.get("HF_TOKEN")
 
15
 
16
  REPO_ID_baseline_model = "Reality8081/bart-base"
17
  REPO_ID_baseline_model_edu = "Reality8081/bart-base-edu"
18
- REPO_ID_baseline_extractive_model = "Reality8081/bart-extractive"
19
- REPO_ID_baseline_extractive_model_edu = "Reality8081/bart-extractive-edu"
20
  REPO_ID_Extabs_model = "Reality8081/bart-encoder-decoder"
21
  REPO_ID_Extabs_model_edu = "Reality8081/bart-encoder-decoder-edu"
22
 
@@ -52,12 +52,14 @@ def model_baseline_extractive(prepro_dict: Dict, top_n = 5) -> str:
52
  repo_id = REPO_ID_baseline_extractive_model_edu
53
  else:
54
  repo_id = REPO_ID_baseline_extractive_model
55
- model = get_summarizer(repo_id=repo_id, base_model_name="facebook/bart-large", device=device)
56
- with torch.no_grad():
 
 
57
  outputs = model(input_ids=input_ids, attention_mask=attention_mask)
58
  # Sử dụng Sigmoid đưa logit về khoảng (0, 1) để lấy xác suất
59
- probs = torch.sigmoid(outputs['logits']).squeeze(0).cpu().numpy()
60
-
61
  segment_scores = []
62
  current_idx = 1 # Bỏ qua token đặc biệt <s> ở đầu chuỗi
63
 
@@ -109,7 +111,7 @@ def model_extractive_abstract(prepro_dict: Dict) -> str:
109
  repo_id = REPO_ID_Extabs_model_edu
110
  else:
111
  repo_id = REPO_ID_Extabs_model
112
- model = get_summarizer(repo_id=repo_id, base_model_name="facebook/bart-large", device=device)
113
  with torch.no_grad():
114
  summary_ids = model.generate_summary(
115
  input_ids=input_ids,
@@ -140,7 +142,7 @@ def ATS(
140
  prepro_dict = preprocess_external_text(text, reference_summary, segmentation_method='edu')
141
 
142
  # Step 2: Chọn model
143
- if model == "Baseline Model: TextRank + Vanilla BART":
144
  result = model_baseline(prepro_dict)
145
  elif model == "Baseline Model with Extractive":
146
  result = model_baseline_extractive(prepro_dict)
@@ -153,11 +155,7 @@ def ATS(
153
  # ====================== GRADIO INTERFACE ======================
154
  with gr.Blocks(
155
  title="Automated Text Summarization System",
156
- theme=gr.themes.Soft(),
157
- css="""
158
- .gradio-container {max-width: 1200px; margin: auto;}
159
- .title {text-align: center; margin-bottom: 10px;}
160
- """
161
  ) as demo:
162
  gr.Markdown(
163
  """
@@ -173,7 +171,6 @@ with gr.Blocks(
173
  placeholder="Paste your long text here (up to several thousand words)...",
174
  lines=12,
175
  max_lines=30,
176
- show_copy_button=True
177
  )
178
 
179
  with gr.Column(scale=1):
@@ -211,7 +208,6 @@ with gr.Blocks(
211
  label="📄 Summary Result",
212
  lines=10,
213
  placeholder="The result will appear here...",
214
- show_copy_button=True
215
  )
216
 
217
  # Connect button click
@@ -245,4 +241,8 @@ with gr.Blocks(
245
 
246
  # Launch the app
247
  if __name__ == "__main__":
248
- demo.launch()
 
 
 
 
 
1
+ import os
2
  import gradio as gr
3
  import torch
4
  from transformers import BartForConditionalGeneration, BartTokenizer
 
6
  import numpy as np
7
  import networkx as nx
8
  from typing import List, Dict
9
+ from src.utils.get_model import get_summarizer, get_extractive_model, get_extractive_abstractive
10
  from src.preprocessing.edu_sentences import preprocess_external_text
 
11
  from src.model.baseline_extractive_model import get_trigrams
12
+ from dotenv import load_dotenv
13
+ load_dotenv() # Tải biến môi trường từ file .env
14
+ os.environ["HF_TOKEN"] = os.getenv("HF_TOKEN") # Thay bằng token của bạn nếu cần
15
 
16
  REPO_ID_baseline_model = "Reality8081/bart-base"
17
  REPO_ID_baseline_model_edu = "Reality8081/bart-base-edu"
18
+ REPO_ID_baseline_extractive_model = "Reality8081/bart_extractive"
19
+ REPO_ID_baseline_extractive_model_edu = "Reality8081/bart_extractive-edu"
20
  REPO_ID_Extabs_model = "Reality8081/bart-encoder-decoder"
21
  REPO_ID_Extabs_model_edu = "Reality8081/bart-encoder-decoder-edu"
22
 
 
52
  repo_id = REPO_ID_baseline_extractive_model_edu
53
  else:
54
  repo_id = REPO_ID_baseline_extractive_model
55
+ model = get_extractive_model(repo_id=repo_id, device=device)
56
+ model = model.to(torch.float32)
57
+ model.eval()
58
+ with torch.no_grad(), torch.autocast(device_type="cuda" if torch.cuda.is_available() else "cpu"):
59
  outputs = model(input_ids=input_ids, attention_mask=attention_mask)
60
  # Sử dụng Sigmoid đưa logit về khoảng (0, 1) để lấy xác suất
61
+ logits = outputs['logits'].to(torch.float32)
62
+ probs = torch.sigmoid(logits).squeeze(0).cpu().numpy()
63
  segment_scores = []
64
  current_idx = 1 # Bỏ qua token đặc biệt <s> ở đầu chuỗi
65
 
 
111
  repo_id = REPO_ID_Extabs_model_edu
112
  else:
113
  repo_id = REPO_ID_Extabs_model
114
+ model = get_extractive_abstractive(repo_id=repo_id, base_model_name="facebook/bart-large", device=device)
115
  with torch.no_grad():
116
  summary_ids = model.generate_summary(
117
  input_ids=input_ids,
 
142
  prepro_dict = preprocess_external_text(text, reference_summary, segmentation_method='edu')
143
 
144
  # Step 2: Chọn model
145
+ if model == "Baseline Model":
146
  result = model_baseline(prepro_dict)
147
  elif model == "Baseline Model with Extractive":
148
  result = model_baseline_extractive(prepro_dict)
 
155
  # ====================== GRADIO INTERFACE ======================
156
  with gr.Blocks(
157
  title="Automated Text Summarization System",
158
+
 
 
 
 
159
  ) as demo:
160
  gr.Markdown(
161
  """
 
171
  placeholder="Paste your long text here (up to several thousand words)...",
172
  lines=12,
173
  max_lines=30,
 
174
  )
175
 
176
  with gr.Column(scale=1):
 
208
  label="📄 Summary Result",
209
  lines=10,
210
  placeholder="The result will appear here...",
 
211
  )
212
 
213
  # Connect button click
 
241
 
242
  # Launch the app
243
  if __name__ == "__main__":
244
+ demo.launch(theme=gr.themes.Soft(),
245
+ css="""
246
+ .gradio-container {max-width: 1200px; margin: auto;}
247
+ .title {text-align: center; margin-bottom: 10px;}
248
+ """)
src/model/baseline_extractive_model.py CHANGED
@@ -2,40 +2,53 @@ import torch
2
  import torch.nn as nn
3
  import numpy as np
4
  from transformers import BartModel, BartTokenizer
 
5
 
6
- class BartExtractiveSummarizer(nn.Module):
7
  def __init__(self, model_name="facebook/bart-large"):
8
  super(BartExtractiveSummarizer, self).__init__()
9
  self.encoder = BartModel.from_pretrained(model_name).encoder
10
  hidden_size = self.encoder.config.hidden_size
11
  self.classifier = nn.Linear(hidden_size, 1)
12
 
 
 
 
13
  def forward(self, input_ids, attention_mask, saliency_mask=None, **kwargs):
 
 
 
 
 
 
 
 
 
 
 
14
  encoder_outputs = self.encoder(
15
  input_ids=input_ids,
16
  attention_mask=attention_mask
17
  )
18
- hidden_states = encoder_outputs.last_hidden_state
19
- logits = self.classifier(hidden_states).squeeze(-1)
20
 
 
 
 
21
  loss = None
22
  if saliency_mask is not None:
23
  active_loss = attention_mask.view(-1) == 1
24
  active_logits = logits.view(-1)[active_loss]
25
  active_labels = saliency_mask.view(-1)[active_loss].float()
26
-
27
- # --- TỐI ƯU: TỰ ĐỘNG TÍNH CLASS WEIGHT CHO TỪNG BATCH ---
28
  num_pos = active_labels.sum()
29
  num_neg = active_labels.size(0) - num_pos
30
-
31
- if num_pos > 0:
32
- weight = torch.clamp(num_neg / num_pos, max=10.0)
33
- else:
34
- weight = torch.tensor(1.0).to(logits.device)
35
-
36
  loss_fct = nn.BCEWithLogitsLoss(pos_weight=weight)
37
  loss = loss_fct(active_logits, active_labels)
38
-
39
  return {"loss": loss, "logits": logits} if loss is not None else {"logits": logits}
40
 
41
  def get_trigrams(text: str):
 
2
  import torch.nn as nn
3
  import numpy as np
4
  from transformers import BartModel, BartTokenizer
5
+ from huggingface_hub import PyTorchModelHubMixin
6
 
7
+ class BartExtractiveSummarizer(nn.Module, PyTorchModelHubMixin):
8
  def __init__(self, model_name="facebook/bart-large"):
9
  super(BartExtractiveSummarizer, self).__init__()
10
  self.encoder = BartModel.from_pretrained(model_name).encoder
11
  hidden_size = self.encoder.config.hidden_size
12
  self.classifier = nn.Linear(hidden_size, 1)
13
 
14
+ # Force float32 from the beginning
15
+ self.to(torch.float32)
16
+
17
  def forward(self, input_ids, attention_mask, saliency_mask=None, **kwargs):
18
+ device = next(self.parameters()).device
19
+
20
+ input_ids = input_ids.to(torch.long).to(device)
21
+ attention_mask = attention_mask.to(torch.long).to(device)
22
+ if saliency_mask is not None:
23
+ saliency_mask = saliency_mask.to(torch.float32).to(device)
24
+
25
+ # Extra safety: ensure encoder stays in float32
26
+ if self.encoder.parameters().__next__().dtype != torch.float32:
27
+ self.encoder = self.encoder.to(torch.float32)
28
+
29
  encoder_outputs = self.encoder(
30
  input_ids=input_ids,
31
  attention_mask=attention_mask
32
  )
 
 
33
 
34
+ hidden_states = encoder_outputs.last_hidden_state.float()
35
+ logits = self.classifier(hidden_states).squeeze(-1)
36
+
37
  loss = None
38
  if saliency_mask is not None:
39
  active_loss = attention_mask.view(-1) == 1
40
  active_logits = logits.view(-1)[active_loss]
41
  active_labels = saliency_mask.view(-1)[active_loss].float()
42
+
 
43
  num_pos = active_labels.sum()
44
  num_neg = active_labels.size(0) - num_pos
45
+
46
+ weight = torch.tensor(num_neg / num_pos if num_pos > 0 else 1.0,
47
+ dtype=torch.float32, device=logits.device)
48
+
 
 
49
  loss_fct = nn.BCEWithLogitsLoss(pos_weight=weight)
50
  loss = loss_fct(active_logits, active_labels)
51
+
52
  return {"loss": loss, "logits": logits} if loss is not None else {"logits": logits}
53
 
54
  def get_trigrams(text: str):
src/model/extabs.py CHANGED
@@ -1,8 +1,9 @@
1
  import torch
2
  import torch.nn as nn
3
  from transformers import BartForConditionalGeneration, BartTokenizer
 
4
 
5
- class EXTABSModel(nn.Module):
6
  def __init__(self, model_name="facebook/bart-large"):
7
  super(EXTABSModel, self).__init__()
8
  # Load kiến trúc BART gốc
 
1
  import torch
2
  import torch.nn as nn
3
  from transformers import BartForConditionalGeneration, BartTokenizer
4
+ from huggingface_hub import PyTorchModelHubMixin
5
 
6
+ class EXTABSModel(nn.Module, PyTorchModelHubMixin):
7
  def __init__(self, model_name="facebook/bart-large"):
8
  super(EXTABSModel, self).__init__()
9
  # Load kiến trúc BART gốc
src/utils/get_model.py CHANGED
@@ -4,7 +4,7 @@ import torch
4
  from huggingface_hub import hf_hub_download
5
  from src.model.extabs import EXTABSModel
6
  import gc
7
-
8
  from safetensors.torch import load_file
9
 
10
  active_model_info = {
@@ -42,14 +42,10 @@ def get_extractive_model(repo_id: str, base_model_name: str = "facebook/bart-lar
42
  print(f"Đang tải mô hình Extractive từ repo: {repo_id}...")
43
 
44
  # Khởi tạo khung kiến trúc trống
45
- model = BartExtractiveSummarizer(model_name=base_model_name)
46
-
47
  # Sử dụng hf_hub_download để kéo file trọng số về local cache
48
- model_path = hf_hub_download(repo_id=repo_id, filename="model_state.bin")
49
- state_dict = load_file(model_path)
50
- # Load trọng số vào model
51
- model.load_state_dict(state_dict)
52
-
53
  model.to(device)
54
  model.eval() # Chuyển mô hình sang chế độ inference
55
 
@@ -57,21 +53,18 @@ def get_extractive_model(repo_id: str, base_model_name: str = "facebook/bart-lar
57
  active_model_info["repo_id"] = repo_id
58
  return active_model_info["model"]
59
 
60
- def get_extractive_abstractive(repo_id: str, base_model_name: str = "facebook/bart-large", device: torch.device = "cpu"):
61
  """Tải và lưu cache mô hình Custom Extractive từ Hugging Face Hub"""
62
  if active_model_info["repo_id"] != repo_id:
63
  clear_memory() # XÓA SẠCH MODEL CŨ
64
  print(f"Đang tải mô hình Extractive từ repo: {repo_id}...")
65
 
66
  # Khởi tạo khung kiến trúc trống
67
- model = EXTABSModel(model_name=base_model_name)
68
 
69
  # Sử dụng hf_hub_download để kéo file trọng số về local cache
70
- model_path = hf_hub_download(repo_id=repo_id, filename="model_state.bin")
71
- state_dict = load_file(model_path)
72
  # Load trọng số vào model
73
- model.load_state_dict(state_dict)
74
-
75
  model.to(device)
76
  model.eval() # Chuyển mô hình sang chế độ inference
77
 
 
4
  from huggingface_hub import hf_hub_download
5
  from src.model.extabs import EXTABSModel
6
  import gc
7
+ import os
8
  from safetensors.torch import load_file
9
 
10
  active_model_info = {
 
42
  print(f"Đang tải mô hình Extractive từ repo: {repo_id}...")
43
 
44
  # Khởi tạo khung kiến trúc trống
45
+ model = BartExtractiveSummarizer.from_pretrained(repo_id, model_name=base_model_name)
 
46
  # Sử dụng hf_hub_download để kéo file trọng số về local cache
47
+ # Load trọng số vào model
48
+ model = model.to(torch.float32)
 
 
 
49
  model.to(device)
50
  model.eval() # Chuyển mô hình sang chế độ inference
51
 
 
53
  active_model_info["repo_id"] = repo_id
54
  return active_model_info["model"]
55
 
56
+ def get_extractive_abstractive(repo_id: str,base_model_name: str = "facebook/bart-large", device: torch.device = "cpu"):
57
  """Tải và lưu cache mô hình Custom Extractive từ Hugging Face Hub"""
58
  if active_model_info["repo_id"] != repo_id:
59
  clear_memory() # XÓA SẠCH MODEL CŨ
60
  print(f"Đang tải mô hình Extractive từ repo: {repo_id}...")
61
 
62
  # Khởi tạo khung kiến trúc trống
63
+ model = EXTABSModel.from_pretrained(repo_id, model_name=base_model_name)
64
 
65
  # Sử dụng hf_hub_download để kéo file trọng số về local cache
 
 
66
  # Load trọng số vào model
67
+ model = model.to(torch.float32)
 
68
  model.to(device)
69
  model.eval() # Chuyển mô hình sang chế độ inference
70