czjun committed on
Commit
c5e3761
·
1 Parent(s): d1a8e7e
Files changed (4) hide show
  1. Dockerfile +16 -0
  2. __pycache__/app.cpython-310.pyc +0 -0
  3. app.py +189 -0
  4. requirements.txt +7 -0
Dockerfile ADDED
@@ -0,0 +1,16 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
FROM python:3.12-slim

WORKDIR /app

# Single ENV instruction (was two) — one image layer instead of two.
# PYTHONDONTWRITEBYTECODE: no .pyc files in the image.
# PYTHONUNBUFFERED: stream logs immediately (important under uvicorn/Docker).
ENV PYTHONDONTWRITEBYTECODE=1 \
    PYTHONUNBUFFERED=1

# Copy and install requirements before the source tree so code-only edits
# do not invalidate the pip layer cache.
COPY requirements.txt /app/requirements.txt
RUN pip install --no-cache-dir -r requirements.txt

COPY . /app

EXPOSE 7860

CMD ["uvicorn", "app:app", "--host", "0.0.0.0", "--port", "7860"]
16
+
__pycache__/app.cpython-310.pyc ADDED
Binary file (7 kB). View file
 
app.py ADDED
@@ -0,0 +1,189 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ from dataclasses import dataclass
4
+ from typing import List, Optional
5
+
6
+ from fastapi import FastAPI
7
+ from pydantic import BaseModel, Field
8
+
9
+ try:
10
+ import torch
11
+ from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
12
+ except Exception: # pragma: no cover
13
+ torch = None
14
+ AutoModelForSeq2SeqLM = None
15
+ AutoTokenizer = None
16
+
17
+
18
@dataclass
class SummaryOutput:
    """Result of one summarization call."""

    # The produced summary text (may be empty for empty input).
    summary: str
    # Which engine produced it: "transformer" or "fallback".
    backend: str
    # The target length the caller requested, echoed back unchanged.
    used_target_length: Optional[int]
23
+
24
+
25
class SummarizationConfig:
    """Class-level generation defaults; used as a namespace, never instantiated."""

    model_name: str = "google/mt5-small"
    # Tokenizer truncation limit for the source prompt.
    max_source_length: int = 1024
    max_target_length: int = 160
    # Beam-search width for generation.
    num_beams: int = 4
    no_repeat_ngram_size: int = 3
    length_penalty: float = 1.0
    # How many sentences the extractive fallback keeps by default.
    fallback_sentences: int = 3
33
+
34
+
35
def normalize_text(text: str) -> str:
    """Collapse whitespace runs (incl. ideographic space U+3000) to single ASCII spaces."""
    without_ideographic = text.replace("\u3000", " ")
    return " ".join(without_ideographic.split())
37
+
38
+
39
def split_sentences(text: str) -> List[str]:
    """Split *text* into sentences after CJK/ASCII terminal punctuation.

    Keeps the terminator attached to its sentence; whitespace-only pieces
    are dropped.
    """
    import re

    pieces = re.split(r"(?<=[。!?!?;;])\s*", text)
    sentences: List[str] = []
    for piece in pieces:
        trimmed = piece.strip()
        if trimmed:
            sentences.append(trimmed)
    return sentences
44
+
45
+
46
def tokenize(text: str) -> List[str]:
    """Lowercase *text* and extract CJK-ideograph runs and alphanumeric runs as tokens."""
    import re

    lowered = text.lower()
    return re.findall(r"[\u4e00-\u9fff]+|[A-Za-z0-9]+", lowered)
50
+
51
+
52
class SimpleExtractiveSummarizer:
    """Frequency-based extractive summarizer used when no transformer is available.

    Scores each sentence by the mean document-frequency of its tokens, keeps the
    top ``max_sentences`` in original document order, and optionally trims to a
    character budget.
    """

    def __init__(self, max_sentences: int = 3):
        # Upper bound on sentences kept in the summary.
        self.max_sentences = max_sentences

    def summarize(self, text: str, target_length: int | None = None) -> str:
        """Return an extractive summary of *text*.

        target_length is a soft character budget: at least one sentence is
        always kept, and a sentence is dropped only if adding it would exceed
        the budget.
        """
        # Local import matches this file's convention (see split_sentences/tokenize).
        from collections import Counter

        sentences = split_sentences(text)
        if not sentences:
            return ""
        if len(sentences) == 1:
            return sentences[0]

        # Document-wide token frequencies (Counter instead of a hand-rolled dict;
        # Counter[missing] == 0, matching the previous dict.get(token, 0)).
        freq = Counter(
            token for sentence in sentences for token in tokenize(sentence)
        )

        scored = []
        for idx, sentence in enumerate(sentences):
            tokens = tokenize(sentence)
            # Mean token frequency; max(1, ...) guards token-less sentences.
            score = sum(freq[token] for token in tokens) / max(1, len(tokens))
            scored.append((score, idx, sentence))

        # Highest score first; ties broken by earlier position.
        scored.sort(key=lambda item: (-item[0], item[1]))
        # Restore document order among the selected sentences.
        selected = sorted(scored[: self.max_sentences], key=lambda item: item[1])

        kept: List[str] = []
        total = 0
        for _, _, sentence in selected:
            # Stop once the budget would be exceeded — but always keep >= 1 sentence.
            if target_length is not None and kept and total + len(sentence) > target_length:
                break
            kept.append(sentence)
            total += len(sentence)
        return "".join(kept or [selected[0][2]])
84
+
85
+
86
class HybridSummarizer:
    """Summarizer that prefers a seq2seq transformer and degrades to extraction.

    The transformer backend is loaded best-effort at construction time; if the
    optional dependencies are missing or loading fails, every request is served
    by SimpleExtractiveSummarizer instead.
    """

    def __init__(self, model_name: str = "google/mt5-small"):
        self.model_name = model_name
        self.backend_name = "fallback"
        self.tokenizer = None
        self.model = None
        self.fallback = SimpleExtractiveSummarizer()
        self.device = "cpu"
        self._try_load_transformer()

    def _try_load_transformer(self) -> None:
        """Best-effort transformer load; leaves the fallback backend on any failure."""
        if AutoTokenizer is None or AutoModelForSeq2SeqLM is None or torch is None:
            return
        try:
            self.tokenizer = AutoTokenizer.from_pretrained(self.model_name)
            self.model = AutoModelForSeq2SeqLM.from_pretrained(self.model_name)
            self.device = "cuda" if torch.cuda.is_available() else "cpu"
            self.model.to(self.device)
            # FIX: inference-only model — switch off dropout/batch-norm training
            # behavior so generation is deterministic and correct.
            self.model.eval()
            self.backend_name = "transformer"
        except Exception:
            # Deliberate best-effort: a failed download/load must not crash the app.
            self.tokenizer = None
            self.model = None
            self.backend_name = "fallback"

    def summarize(self, text: str, target_length: int | None = None) -> SummaryOutput:
        """Summarize *text*, reporting which backend actually produced the result.

        Empty/whitespace-only input short-circuits to an empty summary. A
        transformer failure at generation time silently falls back to the
        extractive summarizer (best-effort by design).
        """
        text = normalize_text(text)
        if not text:
            return SummaryOutput(summary="", backend=self.backend_name, used_target_length=target_length)
        if self.backend_name == "transformer" and self.tokenizer and self.model:
            try:
                return SummaryOutput(
                    summary=self._summarize_with_transformer(text, target_length),
                    backend="transformer",
                    used_target_length=target_length,
                )
            except Exception:
                # Generation errors degrade to the extractive path below.
                pass
        return SummaryOutput(
            summary=self.fallback.summarize(text, target_length=target_length),
            backend="fallback",
            used_target_length=target_length,
        )

    def _summarize_with_transformer(self, text: str, target_length: int | None) -> str:
        """Generate an abstractive summary with the loaded seq2seq model."""
        prompt = f"请根据目标长度 {target_length or 120} 字生成摘要:{text}"
        inputs = self.tokenizer(
            prompt,
            return_tensors="pt",
            truncation=True,
            max_length=SummarizationConfig.max_source_length,
        )
        inputs = {k: v.to(self.device) for k, v in inputs.items()}
        # ~1.2 tokens per requested character, clamped to [32, 256].
        max_new_tokens = max(32, min(256, int((target_length or 120) * 1.2)))
        min_new_tokens = max(16, int(max_new_tokens * 0.4))
        # FIX: generation is pure inference — disable autograd graph construction
        # to cut memory use and speed up beam search.
        with torch.no_grad():
            generated = self.model.generate(
                **inputs,
                max_new_tokens=max_new_tokens,
                min_new_tokens=min_new_tokens,
                num_beams=SummarizationConfig.num_beams,
                no_repeat_ngram_size=SummarizationConfig.no_repeat_ngram_size,
                length_penalty=SummarizationConfig.length_penalty,
                early_stopping=True,
            )
        return self.tokenizer.decode(generated[0], skip_special_tokens=True).strip()
150
+
151
+
152
# Application singletons, created at import time: the ASGI app and one shared
# summarizer engine (the transformer, if available, is loaded here — once).
app = FastAPI(title="Transformer Summarizer Demo", version="1.0.0")
engine = HybridSummarizer()
154
+
155
+
156
class SummarizeRequest(BaseModel):
    """Request body for POST /summarize."""

    # Raw input text to be summarized.
    text: str
    # Soft character budget for the summary; must be >= 1 when provided.
    target_length: int | None = Field(default=120, ge=1, description="目标摘要长度")
159
+
160
+
161
class SummarizeResponse(BaseModel):
    """Response body for POST /summarize."""

    # The produced summary text.
    summary: str
    # Backend that produced it: "transformer" or "fallback".
    backend: str
    # The requested target length, echoed back.
    target_length: int | None
165
+
166
+
167
@app.get("/health")
def health():
    """Liveness probe; also reports which summarization backend is active."""
    payload = {"status": "ok", "backend": engine.backend_name}
    return payload
170
+
171
+
172
@app.post("/summarize", response_model=SummarizeResponse)
def summarize(req: SummarizeRequest):
    """Summarize the request text and report the backend that handled it."""
    result = engine.summarize(req.text, target_length=req.target_length)
    response = SummarizeResponse(
        summary=result.summary,
        backend=result.backend,
        target_length=result.used_target_length,
    )
    return response
180
+
181
+
182
@app.get("/")
def root():
    """Landing endpoint pointing clients at the docs and health check."""
    info = {
        "message": "Transformer Summarizer Demo is running",
        "docs": "/docs",
        "health": "/health",
    }
    return info
189
+
requirements.txt ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ fastapi>=0.110.0
2
+ uvicorn>=0.29.0
3
+ pydantic>=2.7.0
4
+ transformers>=4.41.0
5
+ sentencepiece>=0.2.0
6
+ torch>=2.1.0
7
+