Spaces:
Running
Running
czjun committed on
Commit ·
8d28a45
1
Parent(s): ac5d6e0
Update README and implement training and evaluation scripts for Chinese summarization model
Browse files- Updated README.md to include new training and evaluation instructions.
- Changed default model name in app.py to `fnlp/bart-base-chinese` and adjusted max source length.
- Added a new endpoint `/summarize-plain` in app.py for plain text summarization.
- Updated requirements.txt to include new dependencies: accelerate, rouge-score, and bert-score.
- Created data_utils.py for loading JSONL data and iterating through summarization examples.
- Implemented evaluate.py for model evaluation with ROUGE and BERTScore metrics.
- Developed train.py for fine-tuning the summarization model with specified parameters.
- Added error handling for missing dependencies in evaluation and training scripts.
- README.md +18 -1
- __pycache__/app.cpython-310.pyc +0 -0
- __pycache__/data_utils.cpython-310.pyc +0 -0
- __pycache__/evaluate.cpython-310.pyc +0 -0
- __pycache__/train.cpython-310.pyc +0 -0
- app.py +32 -9
- data_utils.py +34 -0
- evaluate.py +151 -0
- requirements.txt +3 -0
- train.py +128 -0
README.md
CHANGED
|
@@ -12,4 +12,21 @@ Check out the configuration reference at https://huggingface.co/docs/hub/spaces-
|
|
| 12 |
|
| 13 |
To force a specific transformer model in Spaces, set the `MODEL_NAME` environment variable, for example:
|
| 14 |
|
| 15 |
-
`
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 12 |
|
| 13 |
To force a specific transformer model in Spaces, set the `MODEL_NAME` environment variable, for example:
|
| 14 |
|
| 15 |
+
`fnlp/bart-base-chinese`
|
| 16 |
+
|
| 17 |
+
## Training and evaluation
|
| 18 |
+
|
| 19 |
+
For local fine-tuning and metric collection:
|
| 20 |
+
|
| 21 |
+
```bash
|
| 22 |
+
python train.py --train-path data/train.jsonl --valid-path data/valid.jsonl --output-dir outputs/bart_cn
|
| 23 |
+
python evaluate.py --test-path data/test.jsonl --model-name outputs/bart_cn --output-csv metrics_report.csv
|
| 24 |
+
```
|
| 25 |
+
|
| 26 |
+
The evaluation script prints and exports:
|
| 27 |
+
|
| 28 |
+
- `ROUGE-L`
|
| 29 |
+
- `BERTScore`
|
| 30 |
+
- `QAFactEval` when an external QAFactEval environment is available
|
| 31 |
+
- length hit rate
|
| 32 |
+
- average latency
|
__pycache__/app.cpython-310.pyc
CHANGED
|
Binary files a/__pycache__/app.cpython-310.pyc and b/__pycache__/app.cpython-310.pyc differ
|
|
|
__pycache__/data_utils.cpython-310.pyc
ADDED
|
Binary file (1.27 kB). View file
|
|
|
__pycache__/evaluate.cpython-310.pyc
ADDED
|
Binary file (5.44 kB). View file
|
|
|
__pycache__/train.cpython-310.pyc
ADDED
|
Binary file (3.57 kB). View file
|
|
|
app.py
CHANGED
|
@@ -7,6 +7,7 @@ from typing import List, Optional
|
|
| 7 |
|
| 8 |
from fastapi import FastAPI
|
| 9 |
from fastapi.responses import HTMLResponse
|
|
|
|
| 10 |
from pydantic import BaseModel, Field
|
| 11 |
|
| 12 |
try:
|
|
@@ -30,8 +31,8 @@ class SummaryOutput:
|
|
| 30 |
|
| 31 |
|
| 32 |
class SummarizationConfig:
|
| 33 |
-
model_name: str = os.getenv("MODEL_NAME", "
|
| 34 |
-
max_source_length: int =
|
| 35 |
max_target_length: int = 160
|
| 36 |
num_beams: int = 4
|
| 37 |
no_repeat_ngram_size: int = 3
|
|
@@ -151,7 +152,7 @@ class HybridSummarizer:
|
|
| 151 |
prompt,
|
| 152 |
return_tensors="pt",
|
| 153 |
truncation=True,
|
| 154 |
-
max_length=
|
| 155 |
)
|
| 156 |
inputs.pop("token_type_ids", None)
|
| 157 |
inputs = {k: v.to(self.device) for k, v in inputs.items()}
|
|
@@ -209,6 +210,20 @@ def summarize(req: SummarizeRequest):
|
|
| 209 |
)
|
| 210 |
|
| 211 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 212 |
@app.get("/")
|
| 213 |
def root():
|
| 214 |
error_note = f"<p>最近一次生成错误:<code>{engine.load_error}</code></p>" if engine.load_error else ""
|
|
@@ -289,11 +304,18 @@ def root():
|
|
| 289 |
border-radius: 6px;
|
| 290 |
}
|
| 291 |
pre {
|
| 292 |
-
background: #
|
| 293 |
-
color: #
|
| 294 |
padding: 16px;
|
| 295 |
border-radius: 12px;
|
| 296 |
overflow-x: auto;
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 297 |
}
|
| 298 |
.meta {
|
| 299 |
color: #6b7280;
|
|
@@ -319,17 +341,18 @@ def root():
|
|
| 319 |
<div class="guide">
|
| 320 |
<h2>使用指南</h2>
|
| 321 |
<p>1. 点击 <code>打开接口文档</code>,进入 Swagger 页面。</p>
|
| 322 |
-
|
| 323 |
-
|
| 324 |
-
|
| 325 |
"text": "这里放一段较长的中文文本",
|
| 326 |
"target_length": 120
|
| 327 |
}</code></pre>
|
| 328 |
<p>4. 点击 <code>Execute</code> 后查看返回的摘要结果。</p>
|
| 329 |
<p>5. 如果想确认服务是否正常,可点击 <code>检查服务状态</code>,返回 <code>ok</code> 即表示运行正常。</p>
|
| 330 |
<p>6. 如果接口返回 <code>backend=fallback</code>,请查看响应里的 <code>error</code> 字段,这表示 Transformer 生成阶段失败,系统才会自动切回备用摘要。</p>
|
|
|
|
| 331 |
<div class="meta">
|
| 332 |
-
提示:
|
| 333 |
</div>
|
| 334 |
</div>
|
| 335 |
</div>
|
|
|
|
| 7 |
|
| 8 |
from fastapi import FastAPI
|
| 9 |
from fastapi.responses import HTMLResponse
|
| 10 |
+
from fastapi import Body, Query
|
| 11 |
from pydantic import BaseModel, Field
|
| 12 |
|
| 13 |
try:
|
|
|
|
| 31 |
|
| 32 |
|
| 33 |
class SummarizationConfig:
|
| 34 |
+
model_name: str = os.getenv("MODEL_NAME", "fnlp/bart-base-chinese")
|
| 35 |
+
max_source_length: int = 512
|
| 36 |
max_target_length: int = 160
|
| 37 |
num_beams: int = 4
|
| 38 |
no_repeat_ngram_size: int = 3
|
|
|
|
| 152 |
prompt,
|
| 153 |
return_tensors="pt",
|
| 154 |
truncation=True,
|
| 155 |
+
max_length=SummarizationConfig.max_source_length,
|
| 156 |
)
|
| 157 |
inputs.pop("token_type_ids", None)
|
| 158 |
inputs = {k: v.to(self.device) for k, v in inputs.items()}
|
|
|
|
| 210 |
)
|
| 211 |
|
| 212 |
|
| 213 |
+
@app.post("/summarize-plain", response_model=SummarizeResponse)
|
| 214 |
+
def summarize_plain(
|
| 215 |
+
text: str = Body(..., media_type="text/plain", description="直接粘贴原文,支持换行和空格"),
|
| 216 |
+
target_length: int = Query(120, ge=1, description="目标摘要长度"),
|
| 217 |
+
):
|
| 218 |
+
result = engine.summarize(text, target_length=target_length)
|
| 219 |
+
return SummarizeResponse(
|
| 220 |
+
summary=result.summary,
|
| 221 |
+
backend=result.backend,
|
| 222 |
+
target_length=result.used_target_length,
|
| 223 |
+
error=result.error,
|
| 224 |
+
)
|
| 225 |
+
|
| 226 |
+
|
| 227 |
@app.get("/")
|
| 228 |
def root():
|
| 229 |
error_note = f"<p>最近一次生成错误:<code>{engine.load_error}</code></p>" if engine.load_error else ""
|
|
|
|
| 304 |
border-radius: 6px;
|
| 305 |
}
|
| 306 |
pre {
|
| 307 |
+
background: #f8fafc;
|
| 308 |
+
color: #111827;
|
| 309 |
padding: 16px;
|
| 310 |
border-radius: 12px;
|
| 311 |
overflow-x: auto;
|
| 312 |
+
border: 1px solid rgba(148, 163, 184, 0.25);
|
| 313 |
+
}
|
| 314 |
+
pre code {
|
| 315 |
+
background: transparent;
|
| 316 |
+
padding: 0;
|
| 317 |
+
border-radius: 0;
|
| 318 |
+
color: inherit;
|
| 319 |
}
|
| 320 |
.meta {
|
| 321 |
color: #6b7280;
|
|
|
|
| 341 |
<div class="guide">
|
| 342 |
<h2>使用指南</h2>
|
| 343 |
<p>1. 点击 <code>打开接口文档</code>,进入 Swagger 页面。</p>
|
| 344 |
+
<p>2. 找到 <code>POST /summarize</code>,点击 <code>Try it out</code>。</p>
|
| 345 |
+
<p>3. 在请求体中填写文本和目标长度,例如:</p>
|
| 346 |
+
<pre><code>{
|
| 347 |
"text": "这里放一段较长的中文文本",
|
| 348 |
"target_length": 120
|
| 349 |
}</code></pre>
|
| 350 |
<p>4. 点击 <code>Execute</code> 后查看返回的摘要结果。</p>
|
| 351 |
<p>5. 如果想确认服务是否正常,可点击 <code>检查服务状态</code>,返回 <code>ok</code> 即表示运行正常。</p>
|
| 352 |
<p>6. 如果接口返回 <code>backend=fallback</code>,请查看响应里的 <code>error</code> 字段,这表示 Transformer 生成阶段失败,系统才会自动切回备用摘要。</p>
|
| 353 |
+
<p>7. 如果原文包含大量换行或空格,建议直接使用 <code>POST /summarize-plain</code>,把正文当作纯文本提交,更适合粘贴文章正文。</p>
|
| 354 |
<div class="meta">
|
| 355 |
+
提示:<code>/summarize</code> 走 JSON,<code>/summarize-plain</code> 走纯文本。前者适合结构化调用,后者适合直接粘贴文章。
|
| 356 |
</div>
|
| 357 |
</div>
|
| 358 |
</div>
|
data_utils.py
ADDED
|
@@ -0,0 +1,34 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import annotations
|
| 2 |
+
|
| 3 |
+
import json
|
| 4 |
+
from dataclasses import dataclass
|
| 5 |
+
from pathlib import Path
|
| 6 |
+
from typing import Iterable, List
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
@dataclass
class SummarizationExample:
    """One article/summary pair used for training and evaluation."""

    # Full source document text.
    article: str
    # Reference summary paired with the article.
    summary: str
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
def load_jsonl(path: str | Path) -> List[SummarizationExample]:
    """Read a JSONL file and return the usable article/summary pairs.

    Each non-blank line is parsed as a JSON object.  The article text comes
    from the "article" key (falling back to "text") and the summary from
    "summary" (falling back to "label").  Records missing either field are
    silently skipped.
    """
    records: List[SummarizationExample] = []
    with Path(path).open("r", encoding="utf-8") as handle:
        for raw in handle:
            stripped = raw.strip()
            if not stripped:
                continue
            payload = json.loads(stripped)
            body = payload.get("article") or payload.get("text") or ""
            target = payload.get("summary") or payload.get("label") or ""
            # Keep only complete pairs; partial records are useless downstream.
            if body and target:
                records.append(SummarizationExample(article=body, summary=target))
    return records
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
def iter_pairs(examples: Iterable[SummarizationExample]):
    """Yield ``(article, summary)`` tuples from the given examples."""
    for example in examples:
        yield example.article, example.summary
|
| 34 |
+
|
evaluate.py
ADDED
|
@@ -0,0 +1,151 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import annotations
|
| 2 |
+
|
| 3 |
+
import argparse
|
| 4 |
+
import csv
|
| 5 |
+
import time
|
| 6 |
+
from pathlib import Path
|
| 7 |
+
|
| 8 |
+
try:
|
| 9 |
+
from bert_score import score as bertscore
|
| 10 |
+
from rouge_score import rouge_scorer
|
| 11 |
+
import torch
|
| 12 |
+
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
|
| 13 |
+
except Exception as exc: # pragma: no cover
|
| 14 |
+
raise SystemExit(
|
| 15 |
+
"Evaluation requires bert-score, rouge-score, torch and transformers. Install dependencies first."
|
| 16 |
+
) from exc
|
| 17 |
+
|
| 18 |
+
from data_utils import load_jsonl
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
def parse_args():
    """Build and parse the command-line arguments for an evaluation run."""
    cli = argparse.ArgumentParser(description="Evaluate summarization models")
    cli.add_argument("--test-path", required=True)
    cli.add_argument("--model-name", default="fnlp/bart-base-chinese")
    cli.add_argument("--max-source-length", type=int, default=512)
    cli.add_argument("--target-length", type=int, default=120)
    cli.add_argument("--tolerance", type=float, default=0.2)
    cli.add_argument("--output-csv", default="metrics_report.csv")
    cli.add_argument("--qafacteval-model-folder", default=None)
    return cli.parse_args()
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
def length_hit(text: str, target_length: int, tolerance: float) -> bool:
    """Return True when len(text) falls inside the tolerance band.

    The band is ``[int(target*(1-tol)), int(target*(1+tol))]`` — note the
    truncating ``int`` on both bounds, matching character-count semantics.
    """
    span = len(text)
    lower_bound = int(target_length * (1 - tolerance))
    upper_bound = int(target_length * (1 + tolerance))
    return not (span < lower_bound or span > upper_bound)
|
| 37 |
+
|
| 38 |
+
|
| 39 |
+
def try_qafacteval(model_folder: str | None, sources, preds):
    """Score predictions with QAFactEval when its models are available.

    Returns one score (or ``None``) per prediction.  When no model folder is
    configured, or the ``qafacteval`` package cannot be imported, every score
    is ``None`` so callers can treat the metric as optional.
    """
    placeholder = [None] * len(preds)
    if not model_folder:
        return placeholder
    try:
        from qafacteval import QAFactEval
    except Exception:
        # Optional dependency: degrade gracefully instead of failing the run.
        return placeholder
    metric = QAFactEval(
        lerc_quip_path=f"{model_folder}/quip-512-mocha",
        generation_model_path=f"{model_folder}/generation/model.tar.gz",
        answering_model_dir=f"{model_folder}/answering",
        lerc_model_path=f"{model_folder}/lerc/model.tar.gz",
        lerc_pretrained_model_path=f"{model_folder}/lerc/pretraining.tar.gz",
        cuda_device=0 if torch.cuda.is_available() else -1,
        use_lerc_quip=True,
        verbose=False,
        generation_batch_size=8,
        answering_batch_size=8,
        lerc_batch_size=4,
    )
    batch = metric.score_batch(list(sources), [[p] for p in preds], return_qa_pairs=True)
    return [row[0]["qa-eval"].get("lerc_quip") for row in batch]
|
| 65 |
+
|
| 66 |
+
|
| 67 |
+
def main():
    """Generate summaries for the test set and report/export the metrics.

    Prints ROUGE-L, BERTScore, length hit rate, average latency and (when
    available) QAFactEval, then writes the same numbers to a one-row CSV.
    """
    args = parse_args()
    dataset = load_jsonl(args.test_path)

    tokenizer = AutoTokenizer.from_pretrained(args.model_name)
    model = AutoModelForSeq2SeqLM.from_pretrained(args.model_name)
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model.to(device)
    model.eval()

    scorer = rouge_scorer.RougeScorer(["rougeL"], use_stemmer=False)
    sources = []
    refs = []
    preds = []
    times_ms = []
    length_flags = []

    for ex in dataset:
        encoded = tokenizer(
            ex.article,
            return_tensors="pt",
            truncation=True,
            max_length=args.max_source_length,
        )
        # Some Chinese BART tokenizers emit token_type_ids the model rejects.
        encoded.pop("token_type_ids", None)
        encoded = {key: tensor.to(device) for key, tensor in encoded.items()}
        started = time.perf_counter()
        with torch.no_grad():
            generated = model.generate(
                **encoded,
                max_new_tokens=max(48, min(192, int(args.target_length * 1.1))),
                num_beams=4,
                no_repeat_ngram_size=3,
                length_penalty=1.0,
                early_stopping=True,
            )
        elapsed_ms = (time.perf_counter() - started) * 1000
        decoded = tokenizer.decode(generated[0], skip_special_tokens=True).strip()

        sources.append(ex.article)
        refs.append(ex.summary)
        preds.append(decoded)
        times_ms.append(elapsed_ms)
        length_flags.append(length_hit(decoded, args.target_length, args.tolerance))

    rouge_ls = [scorer.score(ref, pred)["rougeL"].fmeasure for ref, pred in zip(refs, preds)]
    P, R, F1 = bertscore(preds, refs, lang="zh", verbose=False)
    qafacteval_scores = try_qafacteval(args.qafacteval_model_folder, sources, preds)

    rouge_l = sum(rouge_ls) / max(1, len(rouge_ls))
    mean_f1 = F1.mean()
    # bert-score returns a tensor; fall back to plain float if it ever doesn't.
    bert_f1 = float(mean_f1.item()) if hasattr(mean_f1, "item") else float(mean_f1)
    length_rate = sum(1 for flag in length_flags if flag) / max(1, len(length_flags))
    avg_latency = sum(times_ms) / max(1, len(times_ms))
    valid_qfe = [s for s in qafacteval_scores if s is not None]
    qafacteval_avg = sum(valid_qfe) / len(valid_qfe) if valid_qfe else None

    print(f"ROUGE-L: {rouge_l:.4f}")
    print(f"BERTScore: {bert_f1:.4f}")
    print(f"Length Hit Rate: {length_rate:.4f}")
    print(f"Avg Latency(ms): {avg_latency:.2f}")
    if qafacteval_avg is not None:
        print(f"QAFactEval: {qafacteval_avg:.4f}")
    else:
        print("QAFactEval: N/A")

    out_path = Path(args.output_csv)
    with out_path.open("w", newline="", encoding="utf-8") as f:
        writer = csv.writer(f)
        writer.writerow(["model", "rouge_l", "bertscore", "qafacteval", "length_hit_rate", "avg_latency_ms"])
        writer.writerow(
            [
                args.model_name,
                f"{rouge_l:.4f}",
                f"{bert_f1:.4f}",
                f"{qafacteval_avg:.4f}" if qafacteval_avg is not None else "",
                f"{length_rate:.4f}",
                f"{avg_latency:.2f}",
            ]
        )
    print(f"saved metrics to {out_path}")


if __name__ == "__main__":
    main()
|
| 151 |
+
|
requirements.txt
CHANGED
|
@@ -5,3 +5,6 @@ transformers>=4.41.0
|
|
| 5 |
sentencepiece>=0.2.0
|
| 6 |
torch>=2.1.0
|
| 7 |
protobuf>=4.25.0
|
|
|
|
|
|
|
|
|
|
|
|
| 5 |
sentencepiece>=0.2.0
|
| 6 |
torch>=2.1.0
|
| 7 |
protobuf>=4.25.0
|
| 8 |
+
accelerate>=0.30.0
|
| 9 |
+
rouge-score>=0.1.2
|
| 10 |
+
bert-score>=0.3.13
|
train.py
ADDED
|
@@ -0,0 +1,128 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import annotations
|
| 2 |
+
|
| 3 |
+
import argparse
|
| 4 |
+
from dataclasses import asdict
|
| 5 |
+
from pathlib import Path
|
| 6 |
+
|
| 7 |
+
try:
|
| 8 |
+
import torch
|
| 9 |
+
from torch.utils.data import Dataset
|
| 10 |
+
from transformers import (
|
| 11 |
+
AutoModelForSeq2SeqLM,
|
| 12 |
+
AutoTokenizer,
|
| 13 |
+
DataCollatorForSeq2Seq,
|
| 14 |
+
Seq2SeqTrainer,
|
| 15 |
+
Seq2SeqTrainingArguments,
|
| 16 |
+
)
|
| 17 |
+
except Exception as exc: # pragma: no cover
|
| 18 |
+
raise SystemExit(
|
| 19 |
+
"Training requires torch, transformers and accelerate. Install dependencies first."
|
| 20 |
+
) from exc
|
| 21 |
+
|
| 22 |
+
from data_utils import load_jsonl
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
class JsonlSeq2SeqDataset(Dataset):
    """Torch dataset that tokenizes JSONL article/summary pairs on access.

    Tokenization happens lazily in ``__getitem__``; padding is left to the
    collator so each item carries only its own (truncated) token ids.
    """

    def __init__(self, path, tokenizer, max_source_length: int, max_target_length: int):
        self.examples = load_jsonl(path)
        self.tokenizer = tokenizer
        self.max_source_length = max_source_length
        self.max_target_length = max_target_length

    def __len__(self):
        return len(self.examples)

    def __getitem__(self, idx):
        sample = self.examples[idx]
        features = self.tokenizer(
            sample.article,
            max_length=self.max_source_length,
            truncation=True,
        )
        target_encoding = self.tokenizer(
            text_target=sample.summary,
            max_length=self.max_target_length,
            truncation=True,
        )
        # The collator expects the reference ids under the "labels" key.
        features["labels"] = target_encoding["input_ids"]
        return features
|
| 49 |
+
|
| 50 |
+
|
| 51 |
+
def parse_args():
    """Build and parse the command-line arguments for a fine-tuning run."""
    cli = argparse.ArgumentParser(description="Fine-tune a Chinese seq2seq summarization model")
    cli.add_argument("--train-path", required=True)
    cli.add_argument("--valid-path", default=None)
    cli.add_argument("--output-dir", required=True)
    cli.add_argument("--model-name", default="fnlp/bart-base-chinese")
    cli.add_argument("--max-source-length", type=int, default=512)
    cli.add_argument("--max-target-length", type=int, default=128)
    cli.add_argument("--num-train-epochs", type=float, default=3.0)
    cli.add_argument("--train-batch-size", type=int, default=2)
    cli.add_argument("--eval-batch-size", type=int, default=2)
    cli.add_argument("--learning-rate", type=float, default=3e-5)
    cli.add_argument("--logging-steps", type=int, default=25)
    cli.add_argument("--save-steps", type=int, default=200)
    return cli.parse_args()
|
| 66 |
+
|
| 67 |
+
|
| 68 |
+
def main():
    """Fine-tune the seq2seq summarization model and save model + tokenizer.

    Loads the base model named by ``--model-name``, builds train/eval datasets
    from JSONL files, trains with ``Seq2SeqTrainer`` and writes the resulting
    checkpoint and tokenizer into ``--output-dir``.
    """
    args = parse_args()
    output_dir = Path(args.output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)

    tokenizer = AutoTokenizer.from_pretrained(args.model_name)
    model = AutoModelForSeq2SeqLM.from_pretrained(args.model_name)

    train_dataset = JsonlSeq2SeqDataset(
        args.train_path,
        tokenizer,
        max_source_length=args.max_source_length,
        max_target_length=args.max_target_length,
    )
    # Evaluation is optional: no --valid-path means train without eval.
    eval_dataset = (
        JsonlSeq2SeqDataset(
            args.valid_path,
            tokenizer,
            max_source_length=args.max_source_length,
            max_target_length=args.max_target_length,
        )
        if args.valid_path
        else None
    )

    data_collator = DataCollatorForSeq2Seq(tokenizer=tokenizer, model=model)

    training_args = Seq2SeqTrainingArguments(
        output_dir=str(output_dir),
        learning_rate=args.learning_rate,
        per_device_train_batch_size=args.train_batch_size,
        per_device_eval_batch_size=args.eval_batch_size,
        predict_with_generate=True,
        num_train_epochs=args.num_train_epochs,
        logging_steps=args.logging_steps,
        save_steps=args.save_steps,
        save_total_limit=2,
        # FIX: `evaluation_strategy` was renamed to `eval_strategy` in
        # transformers 4.41 and the old name is removed in later releases;
        # requirements.txt pins transformers>=4.41.0, so the new name is
        # always available and the old one would crash on current installs.
        eval_strategy="steps" if eval_dataset else "no",
        eval_steps=args.save_steps if eval_dataset else None,
        fp16=torch.cuda.is_available(),
        report_to=[],
    )

    trainer = Seq2SeqTrainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
        tokenizer=tokenizer,
        data_collator=data_collator,
    )

    trainer.train()
    trainer.save_model(str(output_dir))
    tokenizer.save_pretrained(str(output_dir))
    print(f"saved to {output_dir}")


if __name__ == "__main__":
    main()
|
| 128 |
+
|