# Hugging Face Space "czjun" (status: Running).
# Commit 8d28a45: "Update README and implement training and evaluation
# scripts for Chinese summarization model".
from __future__ import annotations

import logging
import os
from collections import Counter
from dataclasses import dataclass
from typing import List, Optional

from fastapi import Body, FastAPI, Query
from fastapi.responses import HTMLResponse
from pydantic import BaseModel, Field
# torch/transformers are optional: if they are missing, the names are nulled
# out and HybridSummarizer silently stays on the extractive fallback backend.
try:
    import torch
    from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
except Exception:  # pragma: no cover
    torch = None
    AutoModelForSeq2SeqLM = None
    AutoTokenizer = None

logger = logging.getLogger(__name__)
@dataclass
class SummaryOutput:
    """Result of one summarization call.

    Attributes:
        summary: The generated (or extracted) summary text.
        backend: Which engine produced it: "transformer" or "fallback".
        used_target_length: The target length the caller asked for, echoed back.
        error: Set when the transformer path failed and the fallback was used.
    """

    # Bug fix: the class was constructed with keyword arguments throughout the
    # file (e.g. SummaryOutput(summary=..., backend=...)) but had no generated
    # __init__ — @dataclass (already imported at module top) provides it.
    summary: str
    backend: str
    used_target_length: Optional[int]
    error: Optional[str] = None
class SummarizationConfig:
    """Namespace of tuning defaults, read via class attributes elsewhere."""

    # Model id; the MODEL_NAME env var overrides the packaged default.
    model_name: str = os.getenv("MODEL_NAME", "fnlp/bart-base-chinese")
    # Token cap applied when encoding the source document.
    max_source_length: int = 512
    # NOTE(review): max_target_length, num_beams, no_repeat_ngram_size and
    # length_penalty are defined here but generation hardcodes its own values
    # (see _summarize_with_transformer) — confirm which set is authoritative.
    max_target_length: int = 160
    num_beams: int = 4
    no_repeat_ngram_size: int = 3
    length_penalty: float = 1.0
    # Default sentence budget for the extractive fallback.
    fallback_sentences: int = 3
def normalize_text(text: str) -> str:
    """Collapse every whitespace run (including ideographic U+3000 spaces)
    into a single ASCII space and trim the ends."""
    without_fullwidth = text.replace("\u3000", " ")
    return " ".join(without_fullwidth.split())
def split_sentences(text: str) -> List[str]:
    """Split *text* into sentences after CJK/ASCII terminal punctuation.

    The lookbehind keeps each terminator attached to its sentence; blank
    fragments are discarded.
    """
    import re

    fragments = re.split(r"(?<=[。!?!?;;])\s*", text)
    sentences: List[str] = []
    for fragment in fragments:
        cleaned = fragment.strip()
        if cleaned:
            sentences.append(cleaned)
    return sentences
def tokenize(text: str) -> List[str]:
    """Lowercase *text*, then return runs of CJK ideographs and runs of
    ASCII alphanumerics as tokens (punctuation and spaces are dropped)."""
    import re

    token_pattern = r"[\u4e00-\u9fff]+|[A-Za-z0-9]+"
    return re.findall(token_pattern, text.lower())
class SimpleExtractiveSummarizer:
    """Frequency-based extractive summarizer used as the fallback backend.

    Scores each sentence by the mean document-wide frequency of its tokens,
    keeps the ``max_sentences`` best, and re-emits them in original order.
    """

    def __init__(self, max_sentences: int = 3):
        # Upper bound on how many sentences the summary may contain.
        self.max_sentences = max_sentences

    def summarize(self, text: str, target_length: int | None = None) -> str:
        """Return an extractive summary of *text*.

        Args:
            text: Source document; split with :func:`split_sentences`.
            target_length: Soft character budget — sentences are dropped once
                exceeding it, but at least one sentence is always kept.
        """
        sentences = split_sentences(text)
        if not sentences:
            return ""
        if len(sentences) == 1:
            return sentences[0]

        # Document-wide token frequencies (Counter returns 0 for unknown
        # tokens, matching the previous dict.get(token, 0) behavior).
        freq = Counter(token for sentence in sentences for token in tokenize(sentence))

        scored = []
        for idx, sentence in enumerate(sentences):
            tokens = tokenize(sentence)
            # Mean frequency; max(1, ...) guards sentences with no tokens.
            score = sum(freq[token] for token in tokens) / max(1, len(tokens))
            scored.append((score, idx, sentence))

        # Best score first, earlier position breaking ties.
        scored.sort(key=lambda item: (-item[0], item[1]))
        # Restore original document order among the winners.
        selected = sorted(scored[: self.max_sentences], key=lambda item: item[1])

        kept: List[str] = []
        total = 0
        for _, _, sentence in selected:
            # Enforce the soft budget, but never drop the first sentence.
            if target_length is not None and kept and total + len(sentence) > target_length:
                break
            kept.append(sentence)
            total += len(sentence)
        # Defensive: fall back to the single top sentence if nothing was kept.
        return "".join(kept or [selected[0][2]])
class HybridSummarizer:
    """Summarizer that prefers a seq2seq transformer and degrades gracefully.

    Loads the configured model once at construction; if loading or generation
    fails for any reason, it transparently answers via
    :class:`SimpleExtractiveSummarizer` and records the error.
    """

    def __init__(self, model_name: str | None = None):
        # Precedence: MODEL_NAME env var > constructor arg > config default.
        self.model_name = os.getenv("MODEL_NAME", model_name or SummarizationConfig.model_name)
        self.backend_name = "fallback"  # flipped to "transformer" on successful load
        self.tokenizer = None
        self.model = None
        self.fallback = SimpleExtractiveSummarizer()
        self.device = "cpu"
        # Human-readable reason the transformer backend is unavailable, if any.
        self.load_error: str | None = None
        self._try_load_transformer()

    def _try_load_transformer(self) -> None:
        """Attempt to load tokenizer + model; on failure stay on the fallback."""
        # The optional deps may be absent entirely (guarded import at module top).
        if AutoTokenizer is None or AutoModelForSeq2SeqLM is None or torch is None:
            self.load_error = "torch/transformers not installed"
            return
        try:
            self.tokenizer = AutoTokenizer.from_pretrained(self.model_name)
            self.model = AutoModelForSeq2SeqLM.from_pretrained(self.model_name)
            self.device = "cuda" if torch.cuda.is_available() else "cpu"
            self.model.to(self.device)
            self.model.eval()
            self.backend_name = "transformer"
            self.load_error = None
        except Exception as exc:
            # Record the failure and revert every field to a clean fallback state.
            self.load_error = f"{type(exc).__name__}: {exc}"
            logger.exception("Failed to load transformer model: %s", self.model_name)
            self.tokenizer = None
            self.model = None
            self.backend_name = "fallback"

    def summarize(self, text: str, target_length: int | None = None) -> SummaryOutput:
        """Summarize *text*, reporting which backend actually produced it.

        Empty (post-normalization) input short-circuits to an empty summary.
        A transformer failure at generation time is caught, logged, and the
        extractive fallback answer is returned with ``error`` populated.
        """
        text = normalize_text(text)
        if not text:
            return SummaryOutput(summary="", backend=self.backend_name, used_target_length=target_length)
        if self.backend_name == "transformer" and self.tokenizer and self.model:
            try:
                return SummaryOutput(
                    summary=self._summarize_with_transformer(text, target_length),
                    backend="transformer",
                    used_target_length=target_length,
                )
            except Exception as exc:
                logger.exception("Transformer generation failed")
                return SummaryOutput(
                    summary=self.fallback.summarize(text, target_length=target_length),
                    backend="fallback",
                    used_target_length=target_length,
                    error=f"{type(exc).__name__}: {exc}",
                )
        return SummaryOutput(
            summary=self.fallback.summarize(text, target_length=target_length),
            backend="fallback",
            used_target_length=target_length,
        )

    def _summarize_with_transformer(self, text: str, target_length: int | None) -> str:
        """Run beam-search generation on the loaded seq2seq model."""
        prompt = text
        inputs = self.tokenizer(
            prompt,
            return_tensors="pt",
            truncation=True,
            max_length=SummarizationConfig.max_source_length,
        )
        # Some Chinese (BERT-style) tokenizers emit token_type_ids, which
        # seq2seq generate() does not accept — drop them if present.
        inputs.pop("token_type_ids", None)
        inputs = {k: v.to(self.device) for k, v in inputs.items()}
        # Scale the token budget from the character target (~1.1 tokens/char),
        # clamped to [48, 192]; default target is 120 characters.
        max_new_tokens = max(48, min(192, int((target_length or 120) * 1.1)))
        with torch.no_grad():
            # NOTE(review): these generation params are hardcoded and diverge
            # from SummarizationConfig (num_beams=4 there vs 2 here) — confirm
            # which is intended before unifying.
            generated = self.model.generate(
                **inputs,
                max_new_tokens=max_new_tokens,
                num_beams=2,
                no_repeat_ngram_size=3,
                length_penalty=1.0,
                early_stopping=True,
            )
        return self.tokenizer.decode(
            generated[0],
            skip_special_tokens=True,
            clean_up_tokenization_spaces=True,
        ).strip()
# Module-level singletons: the ASGI app and one shared summarizer engine —
# the model (if available) is loaded exactly once, at import time.
app = FastAPI(title="Transformer Summarizer Demo", version="1.0.0")
engine = HybridSummarizer()
class SummarizeRequest(BaseModel):
    """JSON request body for the summarize endpoint."""

    # Raw source text to summarize.
    text: str
    # Soft character budget for the summary; must be >= 1 when given.
    target_length: int | None = Field(default=120, ge=1, description="目标摘要长度")
class SummarizeResponse(BaseModel):
    """JSON response body shared by both summarize endpoints."""

    # The produced summary text.
    summary: str
    # Which engine answered: "transformer" or "fallback".
    backend: str
    # Echo of the requested target length.
    target_length: int | None
    # Populated only when the transformer path failed and fallback was used.
    error: str | None = None
@app.get("/health")
def health():
    """Health probe: reports active backend, model id, and any load error.

    Bug fix: the function was never registered as a route (the decorator was
    missing), yet the landing page links to /health — register it.
    """
    return {
        "status": "ok",
        "backend": engine.backend_name,
        "model_name": engine.model_name,
        "load_error": engine.load_error,
    }
@app.post("/summarize", response_model=SummarizeResponse)
def summarize(req: SummarizeRequest):
    """Summarize JSON-supplied text; reports which backend produced it.

    Bug fix: the route decorator was missing, so the endpoint documented on
    the landing page (POST /summarize) was never registered.
    """
    result = engine.summarize(req.text, target_length=req.target_length)
    return SummarizeResponse(
        summary=result.summary,
        backend=result.backend,
        target_length=result.used_target_length,
        error=result.error,
    )
@app.post("/summarize-plain", response_model=SummarizeResponse)
def summarize_plain(
    text: str = Body(..., media_type="text/plain", description="直接粘贴原文,支持换行和空格"),
    target_length: int = Query(120, ge=1, description="目标摘要长度"),
):
    """Summarize a text/plain request body — convenient for pasted articles.

    Bug fix: the route decorator was missing, so the endpoint documented on
    the landing page (POST /summarize-plain) was never registered.
    """
    result = engine.summarize(text, target_length=target_length)
    return SummarizeResponse(
        summary=result.summary,
        backend=result.backend,
        target_length=result.used_target_length,
        error=result.error,
    )
@app.get("/", response_class=HTMLResponse)
def root():
    """Landing page: usage guide plus live backend/model status.

    Bug fixes: (1) the route decorator was missing, so "/" was never served;
    (2) the template is a plain (non-f) string — it cannot be an f-string
    because its CSS/JSON braces would clash — so the {engine.model_name} and
    {engine.backend_name} placeholders rendered literally; they are now
    substituted explicitly with str.replace below.
    """
    error_note = f"<p>最近一次生成错误:<code>{engine.load_error}</code></p>" if engine.load_error else ""
    html = """
<!DOCTYPE html>
<html lang="zh-CN">
<head>
<meta charset="utf-8" />
<meta name="viewport" content="width=device-width, initial-scale=1" />
<title>Transformer Summarizer Demo</title>
<style>
body {
  margin: 0;
  font-family: -apple-system, BlinkMacSystemFont, "Segoe UI", sans-serif;
  background: linear-gradient(135deg, #f7f8fc 0%, #eef4ff 100%);
  color: #1f2937;
}
.wrap {
  max-width: 920px;
  margin: 0 auto;
  padding: 56px 20px 72px;
}
.card {
  background: rgba(255, 255, 255, 0.92);
  border: 1px solid rgba(148, 163, 184, 0.25);
  border-radius: 20px;
  padding: 32px;
  box-shadow: 0 20px 60px rgba(15, 23, 42, 0.08);
  backdrop-filter: blur(8px);
}
h1 {
  margin: 0 0 12px;
  font-size: 34px;
}
h2 {
  margin: 24px 0 10px;
  font-size: 22px;
}
p {
  line-height: 1.75;
  margin: 10px 0;
}
.btns {
  display: flex;
  flex-wrap: wrap;
  gap: 14px;
  margin: 28px 0 18px;
}
a.btn {
  display: inline-block;
  padding: 14px 22px;
  border-radius: 12px;
  text-decoration: none;
  font-weight: 600;
  transition: transform 0.15s ease, box-shadow 0.15s ease;
}
a.btn:hover {
  transform: translateY(-1px);
}
.primary {
  background: #2563eb;
  color: white;
  box-shadow: 0 10px 20px rgba(37, 99, 235, 0.22);
}
.secondary {
  background: white;
  color: #2563eb;
  border: 1px solid rgba(37, 99, 235, 0.2);
}
.guide {
  margin-top: 26px;
  padding-top: 18px;
  border-top: 1px solid rgba(148, 163, 184, 0.25);
}
code {
  background: #eef2ff;
  padding: 2px 6px;
  border-radius: 6px;
}
pre {
  background: #f8fafc;
  color: #111827;
  padding: 16px;
  border-radius: 12px;
  overflow-x: auto;
  border: 1px solid rgba(148, 163, 184, 0.25);
}
pre code {
  background: transparent;
  padding: 0;
  border-radius: 0;
  color: inherit;
}
.meta {
  color: #6b7280;
  font-size: 14px;
  margin-top: 14px;
}
</style>
</head>
<body>
<div class="wrap">
<div class="card">
<h1>Transformer Summarizer Demo</h1>
<p>这是一个基于 Transformer 的中文文本摘要演示系统。你可以通过下面两个按钮进入接口文档或检查服务状态,也可以直接调用摘要接口。</p>
<p>当前模型:<code>{engine.model_name}</code></p>
<p>当前后端:<code>{engine.backend_name}</code></p>
""" + error_note + """
<div class="btns">
<a class="btn primary" href="/docs" target="_blank" rel="noreferrer">打开接口文档</a>
<a class="btn secondary" href="/health" target="_blank" rel="noreferrer">检查服务状态</a>
</div>
<div class="guide">
<h2>使用指南</h2>
<p>1. 点击 <code>打开接口文档</code>,进入 Swagger 页面。</p>
<p>2. 找到 <code>POST /summarize</code>,点击 <code>Try it out</code>。</p>
<p>3. 在请求体中填写文本和目标长度,例如:</p>
<pre><code>{
  "text": "这里放一段较长的中文文本",
  "target_length": 120
}</code></pre>
<p>4. 点击 <code>Execute</code> 后查看返回的摘要结果。</p>
<p>5. 如果想确认服务是否正常,可点击 <code>检查服务状态</code>,返回 <code>ok</code> 即表示运行正常。</p>
<p>6. 如果接口返回 <code>backend=fallback</code>,请查看响应里的 <code>error</code> 字段,这表示 Transformer 生成阶段失败,系统才会自动切回备用摘要。</p>
<p>7. 如果原文包含大量换行或空格,建议直接使用 <code>POST /summarize-plain</code>,把正文当作纯文本提交,更适合粘贴文章正文。</p>
<div class="meta">
提示:<code>/summarize</code> 走 JSON,<code>/summarize-plain</code> 走纯文本。前者适合结构化调用,后者适合直接粘贴文章。
</div>
</div>
</div>
</div>
</body>
</html>
"""
    # Interpolate the live status values; str.replace only touches the exact
    # placeholder substrings, so the CSS/JSON braces above are unaffected.
    html = html.replace("{engine.model_name}", engine.model_name)
    html = html.replace("{engine.backend_name}", engine.backend_name)
    return HTMLResponse(content=html)