czjun commited on
Commit
8d28a45
·
1 Parent(s): ac5d6e0

Update README and implement training and evaluation scripts for Chinese summarization model

Browse files

- Updated README.md to include new training and evaluation instructions.
- Changed default model name in app.py to `fnlp/bart-base-chinese` and adjusted max source length.
- Added a new endpoint `/summarize-plain` in app.py for plain text summarization.
- Updated requirements.txt to include new dependencies: accelerate, rouge-score, and bert-score.
- Created data_utils.py for loading JSONL data and iterating through summarization examples.
- Implemented evaluate.py for model evaluation with ROUGE and BERTScore metrics.
- Developed train.py for fine-tuning the summarization model with specified parameters.
- Added error handling for missing dependencies in evaluation and training scripts.

README.md CHANGED
@@ -12,4 +12,21 @@ Check out the configuration reference at https://huggingface.co/docs/hub/spaces-
12
 
13
  To force a specific transformer model in Spaces, set the `MODEL_NAME` environment variable, for example:
14
 
15
- `IDEA-CCNL/Randeng-T5-Char-57M-MultiTask-Chinese`
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
12
 
13
  To force a specific transformer model in Spaces, set the `MODEL_NAME` environment variable, for example:
14
 
15
+ `fnlp/bart-base-chinese`
16
+
17
+ ## Training and evaluation
18
+
19
+ For local fine-tuning and metric collection:
20
+
21
+ ```bash
22
+ python train.py --train-path data/train.jsonl --valid-path data/valid.jsonl --output-dir outputs/bart_cn
23
+ python evaluate.py --test-path data/test.jsonl --model-name outputs/bart_cn --output-csv metrics_report.csv
24
+ ```
25
+
26
+ The evaluation script prints and exports:
27
+
28
+ - `ROUGE-L`
29
+ - `BERTScore`
30
+ - `QAFactEval` when an external QAFactEval environment is available
31
+ - length hit rate
32
+ - average latency
__pycache__/app.cpython-310.pyc CHANGED
Binary files a/__pycache__/app.cpython-310.pyc and b/__pycache__/app.cpython-310.pyc differ
 
__pycache__/data_utils.cpython-310.pyc ADDED
Binary file (1.27 kB). View file
 
__pycache__/evaluate.cpython-310.pyc ADDED
Binary file (5.44 kB). View file
 
__pycache__/train.cpython-310.pyc ADDED
Binary file (3.57 kB). View file
 
app.py CHANGED
@@ -7,6 +7,7 @@ from typing import List, Optional
7
 
8
  from fastapi import FastAPI
9
  from fastapi.responses import HTMLResponse
 
10
  from pydantic import BaseModel, Field
11
 
12
  try:
@@ -30,8 +31,8 @@ class SummaryOutput:
30
 
31
 
32
  class SummarizationConfig:
33
- model_name: str = os.getenv("MODEL_NAME", "IDEA-CCNL/Randeng-T5-Char-57M-MultiTask-Chinese")
34
- max_source_length: int = 1024
35
  max_target_length: int = 160
36
  num_beams: int = 4
37
  no_repeat_ngram_size: int = 3
@@ -151,7 +152,7 @@ class HybridSummarizer:
151
  prompt,
152
  return_tensors="pt",
153
  truncation=True,
154
- max_length=512,
155
  )
156
  inputs.pop("token_type_ids", None)
157
  inputs = {k: v.to(self.device) for k, v in inputs.items()}
@@ -209,6 +210,20 @@ def summarize(req: SummarizeRequest):
209
  )
210
 
211
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
212
  @app.get("/")
213
  def root():
214
  error_note = f"<p>最近一次生成错误:<code>{engine.load_error}</code></p>" if engine.load_error else ""
@@ -289,11 +304,18 @@ def root():
289
  border-radius: 6px;
290
  }
291
  pre {
292
- background: #0f172a;
293
- color: #e2e8f0;
294
  padding: 16px;
295
  border-radius: 12px;
296
  overflow-x: auto;
 
 
 
 
 
 
 
297
  }
298
  .meta {
299
  color: #6b7280;
@@ -319,17 +341,18 @@ def root():
319
  <div class="guide">
320
  <h2>使用指南</h2>
321
  <p>1. 点击 <code>打开接口文档</code>,进入 Swagger 页面。</p>
322
- <p>2. 找到 <code>POST /summarize</code>,点击 <code>Try it out</code>。</p>
323
- <p>3. 在请求体中填写文本和目标长度,例如:</p>
324
- <pre><code>{
325
  "text": "这里放一段较长的中文文本",
326
  "target_length": 120
327
  }</code></pre>
328
  <p>4. 点击 <code>Execute</code> 后查看返回的摘要结果。</p>
329
  <p>5. 如果想确认服务是否正常,可点击 <code>检查服务状态</code>,返回 <code>ok</code> 即表示运行正常。</p>
330
  <p>6. 如果接口返回 <code>backend=fallback</code>,请查看响应里的 <code>error</code> 字段,这表示 Transformer 生成阶段失败,系统才会自动切回备用摘要。</p>
 
331
  <div class="meta">
332
- 提示:如果文本里有换行请确保是法 JSON。建议直接在 Swagger 页面提交,避免手写 JSON 出错
333
  </div>
334
  </div>
335
  </div>
 
7
 
8
  from fastapi import FastAPI
9
  from fastapi.responses import HTMLResponse
10
+ from fastapi import Body, Query
11
  from pydantic import BaseModel, Field
12
 
13
  try:
 
31
 
32
 
33
  class SummarizationConfig:
34
+ model_name: str = os.getenv("MODEL_NAME", "fnlp/bart-base-chinese")
35
+ max_source_length: int = 512
36
  max_target_length: int = 160
37
  num_beams: int = 4
38
  no_repeat_ngram_size: int = 3
 
152
  prompt,
153
  return_tensors="pt",
154
  truncation=True,
155
+ max_length=SummarizationConfig.max_source_length,
156
  )
157
  inputs.pop("token_type_ids", None)
158
  inputs = {k: v.to(self.device) for k, v in inputs.items()}
 
210
  )
211
 
212
 
213
@app.post("/summarize-plain", response_model=SummarizeResponse)
def summarize_plain(
    text: str = Body(..., media_type="text/plain", description="直接粘贴原文,支持换行和空格"),
    target_length: int = Query(120, ge=1, description="目标摘要长度"),
):
    """Summarize a pasted article submitted as a raw text body.

    Convenience companion to ``POST /summarize``: the article goes in the
    request body, the desired summary length in the ``target_length`` query
    parameter.  The response mirrors ``/summarize`` (summary, backend,
    effective target length, optional error).

    NOTE(review): FastAPI advertises this body as ``text/plain`` in the
    OpenAPI docs, but a singular ``str`` body parameter is still decoded
    from JSON by default -- confirm that truly raw (unquoted) payloads are
    accepted, or switch to reading ``Request.body()`` if they are not.
    """
    result = engine.summarize(text, target_length=target_length)
    return SummarizeResponse(
        summary=result.summary,
        backend=result.backend,
        target_length=result.used_target_length,
        error=result.error,
    )
225
+
226
+
227
  @app.get("/")
228
  def root():
229
  error_note = f"<p>最近一次生成错误:<code>{engine.load_error}</code></p>" if engine.load_error else ""
 
304
  border-radius: 6px;
305
  }
306
  pre {
307
+ background: #f8fafc;
308
+ color: #111827;
309
  padding: 16px;
310
  border-radius: 12px;
311
  overflow-x: auto;
312
+ border: 1px solid rgba(148, 163, 184, 0.25);
313
+ }
314
+ pre code {
315
+ background: transparent;
316
+ padding: 0;
317
+ border-radius: 0;
318
+ color: inherit;
319
  }
320
  .meta {
321
  color: #6b7280;
 
341
  <div class="guide">
342
  <h2>使用指南</h2>
343
  <p>1. 点击 <code>打开接口文档</code>,进入 Swagger 页面。</p>
344
+ <p>2. 找到 <code>POST /summarize</code>,点击 <code>Try it out</code>。</p>
345
+ <p>3. 在请求体中填写文本和目标长度,例如:</p>
346
+ <pre><code>{
347
  "text": "这里放一段较长的中文文本",
348
  "target_length": 120
349
  }</code></pre>
350
  <p>4. 点击 <code>Execute</code> 后查看返回的摘要结果。</p>
351
  <p>5. 如果想确认服务是否正常,可点击 <code>检查服务状态</code>,返回 <code>ok</code> 即表示运行正常。</p>
352
  <p>6. 如果接口返回 <code>backend=fallback</code>,请查看响应里的 <code>error</code> 字段,这表示 Transformer 生成阶段失败,系统才会自动切回备用摘要。</p>
353
+ <p>7. 如果原文包含大量换行或空格,建议直接使用 <code>POST /summarize-plain</code>,把正文当作纯文本提交,更适合粘贴文章正文。</p>
354
  <div class="meta">
355
+ 提示:<code>/summarize</code> 走 JSON,<code>/summarize-plain</code> 走纯文本。前者适合结构化调用后者适合直接粘贴文章
356
  </div>
357
  </div>
358
  </div>
data_utils.py ADDED
@@ -0,0 +1,34 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ import json
4
+ from dataclasses import dataclass
5
+ from pathlib import Path
6
+ from typing import Iterable, List
7
+
8
+
9
@dataclass
class SummarizationExample:
    """One article/reference-summary pair read from a JSONL dataset."""

    # Full source text to be summarized.
    article: str
    # Gold summary used as the training/evaluation reference.
    summary: str


def load_jsonl(path: str | Path) -> List[SummarizationExample]:
    """Load article/summary pairs from a JSONL file.

    Every non-blank line is parsed as one JSON object.  The article text is
    taken from the ``"article"`` key (falling back to ``"text"``) and the
    reference summary from ``"summary"`` (falling back to ``"label"``).
    Records missing either field are dropped.
    """
    examples: List[SummarizationExample] = []
    with Path(path).open("r", encoding="utf-8") as handle:
        for raw_line in handle:
            record_text = raw_line.strip()
            if not record_text:
                continue
            record = json.loads(record_text)
            article_text = record.get("article") or record.get("text") or ""
            summary_text = record.get("summary") or record.get("label") or ""
            if article_text and summary_text:
                examples.append(
                    SummarizationExample(article=article_text, summary=summary_text)
                )
    return examples
29
+
30
+
31
def iter_pairs(examples: Iterable[SummarizationExample]):
    """Yield ``(article, summary)`` tuples, one per example."""
    yield from ((item.article, item.summary) for item in examples)
34
+
evaluate.py ADDED
@@ -0,0 +1,151 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ import argparse
4
+ import csv
5
+ import time
6
+ from pathlib import Path
7
+
8
+ try:
9
+ from bert_score import score as bertscore
10
+ from rouge_score import rouge_scorer
11
+ import torch
12
+ from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
13
+ except Exception as exc: # pragma: no cover
14
+ raise SystemExit(
15
+ "Evaluation requires bert-score, rouge-score, torch and transformers. Install dependencies first."
16
+ ) from exc
17
+
18
+ from data_utils import load_jsonl
19
+
20
+
21
def parse_args(argv=None):
    """Parse command-line options for the evaluation run.

    Args:
        argv: Optional list of argument strings.  Defaults to ``None`` so
            argparse reads ``sys.argv[1:]``, keeping existing CLI usage
            unchanged while allowing programmatic invocation and testing.

    Returns:
        argparse.Namespace with test-path, model/generation settings,
        the length-tolerance window, CSV output path and the optional
        QAFactEval model folder.
    """
    parser = argparse.ArgumentParser(description="Evaluate summarization models")
    parser.add_argument("--test-path", required=True)
    parser.add_argument("--model-name", default="fnlp/bart-base-chinese")
    parser.add_argument("--max-source-length", type=int, default=512)
    parser.add_argument("--target-length", type=int, default=120)
    parser.add_argument("--tolerance", type=float, default=0.2)
    parser.add_argument("--output-csv", default="metrics_report.csv")
    parser.add_argument("--qafacteval-model-folder", default=None)
    return parser.parse_args(argv)
31
+
32
+
33
def length_hit(text: str, target_length: int, tolerance: float) -> bool:
    """Return True when ``len(text)`` lands inside the tolerance window.

    The window is ``[target*(1-tolerance), target*(1+tolerance)]`` with both
    endpoints truncated toward zero, matching how the length hit rate is
    reported by the evaluation script.
    """
    actual = len(text)
    if actual < int(target_length * (1 - tolerance)):
        return False
    return actual <= int(target_length * (1 + tolerance))
37
+
38
+
39
def try_qafacteval(model_folder: str | None, sources, preds):
    """Best-effort QAFactEval factual-consistency scoring.

    Returns a list aligned with ``preds``: one LERC-QUIP score per
    prediction, or all-``None`` placeholders when no model folder was
    supplied or the optional ``qafacteval`` package is not installed.
    """
    if not model_folder:
        return [None] * len(preds)
    try:
        from qafacteval import QAFactEval
    except Exception:
        # Optional dependency: fall back silently rather than failing the run.
        return [None] * len(preds)
    metric = QAFactEval(
        # Checkpoint layout expected inside model_folder; these relative
        # paths follow the QAFactEval release's folder structure.
        lerc_quip_path=f"{model_folder}/quip-512-mocha",
        generation_model_path=f"{model_folder}/generation/model.tar.gz",
        answering_model_dir=f"{model_folder}/answering",
        lerc_model_path=f"{model_folder}/lerc/model.tar.gz",
        lerc_pretrained_model_path=f"{model_folder}/lerc/pretraining.tar.gz",
        cuda_device=0 if torch.cuda.is_available() else -1,
        use_lerc_quip=True,
        verbose=False,
        generation_batch_size=8,
        answering_batch_size=8,
        lerc_batch_size=4,
    )
    # score_batch takes one list of candidate summaries per source document,
    # hence the single-element inner lists.
    results = metric.score_batch(list(sources), [[p] for p in preds], return_qa_pairs=True)
    scores = []
    for row in results:
        # NOTE(review): assumes each result row is (metrics_dict, ...) with a
        # "qa-eval"/"lerc_quip" entry -- confirm against the installed
        # qafacteval version; missing keys yield None via .get().
        item = row[0]["qa-eval"].get("lerc_quip")
        scores.append(item)
    return scores
65
+
66
+
67
def main():
    """Generate summaries for the test set and report/export quality metrics.

    Prints ROUGE-L, BERTScore F1, length hit rate, average generation
    latency and (when available) QAFactEval, then writes a one-row CSV
    summary to ``--output-csv``.
    """
    args = parse_args()
    examples = load_jsonl(args.test_path)

    tokenizer = AutoTokenizer.from_pretrained(args.model_name)
    model = AutoModelForSeq2SeqLM.from_pretrained(args.model_name)
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model.to(device)
    model.eval()

    # ROUGE-L only; stemming is irrelevant for Chinese text.
    scorer = rouge_scorer.RougeScorer(["rougeL"], use_stemmer=False)
    sources = []
    refs = []
    preds = []
    times_ms = []
    length_flags = []

    for ex in examples:
        inputs = tokenizer(
            ex.article,
            return_tensors="pt",
            truncation=True,
            max_length=args.max_source_length,
        )
        # Some tokenizers emit token_type_ids that seq2seq generate() rejects.
        inputs.pop("token_type_ids", None)
        inputs = {k: v.to(device) for k, v in inputs.items()}
        start = time.perf_counter()
        with torch.no_grad():
            out = model.generate(
                **inputs,
                # Budget roughly 1.1 tokens per requested character,
                # clamped to the [48, 192] range.
                max_new_tokens=max(48, min(192, int(args.target_length * 1.1))),
                num_beams=4,
                no_repeat_ngram_size=3,
                length_penalty=1.0,
                early_stopping=True,
            )
        # Latency covers generation only; tokenization/decoding excluded.
        elapsed_ms = (time.perf_counter() - start) * 1000
        pred = tokenizer.decode(out[0], skip_special_tokens=True).strip()

        sources.append(ex.article)
        refs.append(ex.summary)
        preds.append(pred)
        times_ms.append(elapsed_ms)
        length_flags.append(length_hit(pred, args.target_length, args.tolerance))

    rouge_ls = [scorer.score(ref, pred)["rougeL"].fmeasure for ref, pred in zip(refs, preds)]
    # Only the F1 tensor is used below; precision/recall are discarded.
    P, R, F1 = bertscore(preds, refs, lang="zh", verbose=False)
    qafacteval_scores = try_qafacteval(args.qafacteval_model_folder, sources, preds)

    # max(1, ...) guards against an empty test set (no ZeroDivisionError).
    rouge_l = sum(rouge_ls) / max(1, len(rouge_ls))
    bert_f1 = float(F1.mean().item()) if hasattr(F1.mean(), "item") else float(F1.mean())
    length_rate = sum(1 for v in length_flags if v) / max(1, len(length_flags))
    avg_latency = sum(times_ms) / max(1, len(times_ms))
    # QAFactEval entries may be None (disabled/unavailable); average the rest.
    qafacteval_valid = [s for s in qafacteval_scores if s is not None]
    qafacteval_avg = sum(qafacteval_valid) / len(qafacteval_valid) if qafacteval_valid else None

    print(f"ROUGE-L: {rouge_l:.4f}")
    print(f"BERTScore: {bert_f1:.4f}")
    print(f"Length Hit Rate: {length_rate:.4f}")
    print(f"Avg Latency(ms): {avg_latency:.2f}")
    if qafacteval_avg is not None:
        print(f"QAFactEval: {qafacteval_avg:.4f}")
    else:
        print("QAFactEval: N/A")

    out_path = Path(args.output_csv)
    with out_path.open("w", newline="", encoding="utf-8") as f:
        writer = csv.writer(f)
        writer.writerow(["model", "rouge_l", "bertscore", "qafacteval", "length_hit_rate", "avg_latency_ms"])
        writer.writerow(
            [
                args.model_name,
                f"{rouge_l:.4f}",
                f"{bert_f1:.4f}",
                # Empty cell when QAFactEval was unavailable.
                f"{qafacteval_avg:.4f}" if qafacteval_avg is not None else "",
                f"{length_rate:.4f}",
                f"{avg_latency:.2f}",
            ]
        )
    print(f"saved metrics to {out_path}")
147
+
148
+
149
+ if __name__ == "__main__":
150
+ main()
151
+
requirements.txt CHANGED
@@ -5,3 +5,6 @@ transformers>=4.41.0
5
  sentencepiece>=0.2.0
6
  torch>=2.1.0
7
  protobuf>=4.25.0
 
 
 
 
5
  sentencepiece>=0.2.0
6
  torch>=2.1.0
7
  protobuf>=4.25.0
8
+ accelerate>=0.30.0
9
+ rouge-score>=0.1.2
10
+ bert-score>=0.3.13
train.py ADDED
@@ -0,0 +1,128 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ import argparse
4
+ from dataclasses import asdict
5
+ from pathlib import Path
6
+
7
+ try:
8
+ import torch
9
+ from torch.utils.data import Dataset
10
+ from transformers import (
11
+ AutoModelForSeq2SeqLM,
12
+ AutoTokenizer,
13
+ DataCollatorForSeq2Seq,
14
+ Seq2SeqTrainer,
15
+ Seq2SeqTrainingArguments,
16
+ )
17
+ except Exception as exc: # pragma: no cover
18
+ raise SystemExit(
19
+ "Training requires torch, transformers and accelerate. Install dependencies first."
20
+ ) from exc
21
+
22
+ from data_utils import load_jsonl
23
+
24
+
25
class JsonlSeq2SeqDataset(Dataset):
    """Torch dataset that tokenizes JSONL article/summary pairs on the fly."""

    def __init__(self, path, tokenizer, max_source_length: int, max_target_length: int):
        # Eagerly loads all examples into memory via load_jsonl; suited to
        # the modest fine-tuning sets this script targets.
        self.examples = load_jsonl(path)
        self.tokenizer = tokenizer
        self.max_source_length = max_source_length
        self.max_target_length = max_target_length

    def __len__(self) -> int:
        return len(self.examples)

    def __getitem__(self, idx):
        """Return tokenized model inputs plus label ids for one example.

        No padding is applied here -- each batch is padded dynamically by
        the ``DataCollatorForSeq2Seq`` configured in ``main``.
        """
        ex = self.examples[idx]
        model_inputs = self.tokenizer(
            ex.article,
            max_length=self.max_source_length,
            truncation=True,
        )
        # text_target= tokenizes the summary with the tokenizer's
        # target-side settings.
        labels = self.tokenizer(
            text_target=ex.summary,
            max_length=self.max_target_length,
            truncation=True,
        )
        model_inputs["labels"] = labels["input_ids"]
        return model_inputs
49
+
50
+
51
def parse_args(argv=None):
    """Parse command-line options for fine-tuning.

    Args:
        argv: Optional list of argument strings.  Defaults to ``None`` so
            argparse reads ``sys.argv[1:]``, keeping existing CLI usage
            unchanged while allowing programmatic invocation and testing.

    Returns:
        argparse.Namespace with dataset paths, output directory, model
        name, tokenization lengths and trainer hyperparameters.
    """
    parser = argparse.ArgumentParser(description="Fine-tune a Chinese seq2seq summarization model")
    parser.add_argument("--train-path", required=True)
    parser.add_argument("--valid-path", default=None)
    parser.add_argument("--output-dir", required=True)
    parser.add_argument("--model-name", default="fnlp/bart-base-chinese")
    parser.add_argument("--max-source-length", type=int, default=512)
    parser.add_argument("--max-target-length", type=int, default=128)
    parser.add_argument("--num-train-epochs", type=float, default=3.0)
    parser.add_argument("--train-batch-size", type=int, default=2)
    parser.add_argument("--eval-batch-size", type=int, default=2)
    parser.add_argument("--learning-rate", type=float, default=3e-5)
    parser.add_argument("--logging-steps", type=int, default=25)
    parser.add_argument("--save-steps", type=int, default=200)
    return parser.parse_args(argv)
66
+
67
+
68
def main():
    """Fine-tune the seq2seq model and save model + tokenizer to --output-dir."""
    args = parse_args()
    output_dir = Path(args.output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)

    tokenizer = AutoTokenizer.from_pretrained(args.model_name)
    model = AutoModelForSeq2SeqLM.from_pretrained(args.model_name)

    train_dataset = JsonlSeq2SeqDataset(
        args.train_path,
        tokenizer,
        max_source_length=args.max_source_length,
        max_target_length=args.max_target_length,
    )
    # Evaluation is optional: only built when --valid-path is provided.
    eval_dataset = (
        JsonlSeq2SeqDataset(
            args.valid_path,
            tokenizer,
            max_source_length=args.max_source_length,
            max_target_length=args.max_target_length,
        )
        if args.valid_path
        else None
    )

    # Pads inputs and labels dynamically per batch (dataset items are unpadded).
    data_collator = DataCollatorForSeq2Seq(tokenizer=tokenizer, model=model)

    training_args = Seq2SeqTrainingArguments(
        output_dir=str(output_dir),
        learning_rate=args.learning_rate,
        per_device_train_batch_size=args.train_batch_size,
        per_device_eval_batch_size=args.eval_batch_size,
        predict_with_generate=True,
        num_train_epochs=args.num_train_epochs,
        logging_steps=args.logging_steps,
        save_steps=args.save_steps,
        # Keep only the two most recent checkpoints on disk.
        save_total_limit=2,
        # Evaluate on the same cadence as checkpointing when a valid set exists.
        # NOTE(review): `evaluation_strategy` was renamed `eval_strategy` in
        # newer transformers releases -- confirm against the pinned version.
        evaluation_strategy="steps" if eval_dataset else "no",
        eval_steps=args.save_steps if eval_dataset else None,
        # Mixed precision only when a CUDA device is available.
        fp16=torch.cuda.is_available(),
        # Empty list disables experiment trackers (wandb/tensorboard/etc.).
        report_to=[],
    )

    trainer = Seq2SeqTrainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
        tokenizer=tokenizer,
        data_collator=data_collator,
    )

    trainer.train()
    trainer.save_model(str(output_dir))
    # Save the tokenizer alongside the model so output_dir is fully loadable
    # via from_pretrained.
    tokenizer.save_pretrained(str(output_dir))
    print(f"saved to {output_dir}")
124
+
125
+
126
+ if __name__ == "__main__":
127
+ main()
128
+