# app.py — AngelSlim Translate: a FastAPI translation server with SQLite caching.
import hashlib
import os
import socket
import sqlite3
import threading
import time
from functools import lru_cache
from pathlib import Path
from typing import Optional
from fastapi import FastAPI, HTTPException
from fastapi.responses import FileResponse
from fastapi.staticfiles import StaticFiles
from pydantic import BaseModel, Field
# Resolve paths relative to this file so the app works from any working directory.
APP_DIR = Path(__file__).resolve().parent
STATIC_DIR = APP_DIR / "static"
# SQLite translation cache; location overridable via the CACHE_DB env var.
CACHE_DB_PATH = Path(os.getenv("CACHE_DB", str(APP_DIR / "translations.db")))
# Hugging Face model id used for translation (overridable via MODEL_ID).
MODEL_ID = os.getenv("MODEL_ID", "AngelSlim/Hy-MT1.5-1.8B-1.25bit")
# Request/generation limits, all tunable through environment variables.
MAX_INPUT_CHARS = int(os.getenv("MAX_INPUT_CHARS", "6000"))
DEFAULT_MAX_TOKENS = int(os.getenv("MAX_TOKENS", "256"))
DEFAULT_TEMPERATURE = float(os.getenv("TEMPERATURE", "0.7"))
# Supported language codes mapped to display names; "auto" means
# "let the model detect the source language" and is only valid as a source.
LANGUAGES = {
    "auto": "Auto Detect",
    "en": "English",
    "zh": "Chinese",
    "ja": "Japanese",
    "ko": "Korean",
    "fr": "French",
    "de": "German",
    "es": "Spanish",
    "it": "Italian",
    "pt": "Portuguese",
    "ru": "Russian",
    "ar": "Arabic",
    "hi": "Hindi",
    "vi": "Vietnamese",
    "th": "Thai",
    "id": "Indonesian",
    "ms": "Malay",
    "tr": "Turkish",
    "nl": "Dutch",
    "pl": "Polish",
    "uk": "Ukrainian",
    "cs": "Czech",
    "sv": "Swedish",
    "fi": "Finnish",
    "el": "Greek",
    "he": "Hebrew",
}
class TranslateRequest(BaseModel):
    """Request body for POST /api/translate."""

    # Text to translate; length bounded by MAX_INPUT_CHARS.
    text: str = Field(..., min_length=1, max_length=MAX_INPUT_CHARS)
    # Language codes from LANGUAGES; "auto" lets the model detect the source.
    source_language: str = "auto"
    target_language: str = "en"
class TranslateResponse(BaseModel):
    """Response body for POST /api/translate."""

    # The translated text produced by the model (or served from cache).
    translation: str
    # Optional romanization of the translation and its scheme name
    # (e.g. "Pinyin", "Romaji"); both None when no scheme applies.
    transliteration: Optional[str] = None
    transliteration_label: Optional[str] = None
    # Wall-clock generation time in milliseconds (the cached value on a hit).
    elapsed_ms: int
    model: str
    # True when the result was served from the SQLite cache.
    cached: bool = False
# Serializes all access to the shared SQLite connection across request threads.
_db_lock = threading.Lock()
# Lazily created singleton connection; see _get_db().
_db_conn: Optional[sqlite3.Connection] = None
def _get_db() -> sqlite3.Connection:
    """Return the process-wide SQLite connection, creating it on first use.

    Fix: the original checked ``_db_conn is None`` without holding the lock,
    so two request threads racing through the first call could each open a
    connection (one of them leaking). Initialization now uses double-checked
    locking under ``_db_lock``.

    ``check_same_thread=False`` is acceptable because every statement in this
    module runs while holding ``_db_lock``.

    Returns:
        The shared, schema-initialized sqlite3.Connection.
    """
    global _db_conn
    if _db_conn is None:
        with _db_lock:
            # Re-check under the lock: another thread may have won the race.
            if _db_conn is None:
                CACHE_DB_PATH.parent.mkdir(parents=True, exist_ok=True)
                conn = sqlite3.connect(str(CACHE_DB_PATH), check_same_thread=False)
                # WAL + synchronous=NORMAL trades a little durability for
                # much better concurrent read/write throughput.
                conn.execute("PRAGMA journal_mode=WAL")
                conn.execute("PRAGMA synchronous=NORMAL")
                conn.execute(
                    """
                    CREATE TABLE IF NOT EXISTS translations (
                        model TEXT NOT NULL,
                        source_language TEXT NOT NULL,
                        target_language TEXT NOT NULL,
                        text_hash TEXT NOT NULL,
                        source_text TEXT NOT NULL,
                        translation TEXT NOT NULL,
                        transliteration TEXT,
                        transliteration_label TEXT,
                        elapsed_ms INTEGER NOT NULL,
                        created_at REAL NOT NULL,
                        PRIMARY KEY (model, source_language, target_language, text_hash)
                    )
                    """
                )
                conn.commit()
                _db_conn = conn
    return _db_conn
def _hash_text(text: str) -> str:
return hashlib.sha256(text.encode("utf-8")).hexdigest()
def cache_lookup(text: str, source_language: str, target_language: str) -> Optional[dict]:
    """Return the cached translation record for this request, or None on a miss.

    The cache key is (model, source language, target language, text hash), so
    changing MODEL_ID naturally invalidates old entries.
    """
    conn = _get_db()
    key = (MODEL_ID, source_language, target_language, _hash_text(text))
    with _db_lock:
        row = conn.execute(
            """
            SELECT translation, transliteration, transliteration_label, elapsed_ms
            FROM translations
            WHERE model = ? AND source_language = ? AND target_language = ? AND text_hash = ?
            """,
            key,
        ).fetchone()
    if row is None:
        return None
    translation, transliteration, label, elapsed_ms = row
    return {
        "translation": translation,
        "transliteration": transliteration,
        "transliteration_label": label,
        "elapsed_ms": elapsed_ms,
    }
def cache_store(
    text: str,
    source_language: str,
    target_language: str,
    translation: str,
    transliteration: Optional[str],
    transliteration_label: Optional[str],
    elapsed_ms: int,
) -> None:
    """Insert (or overwrite) one cached translation row for the current model."""
    conn = _get_db()
    row = (
        MODEL_ID,
        source_language,
        target_language,
        _hash_text(text),
        text,
        translation,
        transliteration,
        transliteration_label,
        elapsed_ms,
        time.time(),
    )
    with _db_lock:
        # INSERT OR REPLACE keeps the primary key unique on re-translation.
        conn.execute(
            """
            INSERT OR REPLACE INTO translations
            (model, source_language, target_language, text_hash, source_text,
            translation, transliteration, transliteration_label, elapsed_ms, created_at)
            VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
            """,
            row,
        )
        conn.commit()
def transliterate(text: str, language: str) -> tuple[Optional[str], Optional[str]]:
    """Best-effort romanization of *text* for a target-language code.

    Returns ``(romanized_text, scheme_label)``, or ``(None, None)`` when no
    scheme applies to the language, the optional backend package is missing,
    or romanization produces nothing new.
    """
    if not text:
        return None, None
    try:
        if language == "zh":
            from pypinyin import Style, pinyin

            syllables = pinyin(text, style=Style.TONE, errors="default")
            rendered = " ".join("".join(chunk) for chunk in syllables if chunk)
            return (rendered or None), "Pinyin"
        if language == "ja":
            import pykakasi

            converter = pykakasi.kakasi()
            pieces = converter.convert(text)
            rendered = " ".join(
                piece["hepburn"] for piece in pieces if piece.get("hepburn")
            ).strip()
            return (rendered or None), "Romaji"
        if language in {"ko", "ru", "uk", "el", "he", "ar", "hi", "th"}:
            from unidecode import unidecode

            romanized = unidecode(text).strip()
            # Skip when unidecode yields nothing or changed nothing.
            if romanized and romanized != text:
                # One generic label covers every script unidecode handles.
                return romanized, "Romanization"
            return None, None
    except ImportError:
        # The romanization backends are optional dependencies.
        return None, None
    except Exception:
        # Romanization is cosmetic; never let it break a translation.
        return None, None
    return None, None
# FastAPI application: serves the SPA assets under /static and the JSON API below.
app = FastAPI(title="AngelSlim Translate")
app.mount("/static", StaticFiles(directory=STATIC_DIR), name="static")
def choose_port(host: str, preferred_port: int) -> int:
    """Return the first bindable TCP port in [preferred_port, preferred_port + 49].

    Each candidate is probed with a throwaway socket (SO_REUSEADDR set, as in
    the original) and the first successful bind wins.

    Raises:
        RuntimeError: when all 50 candidate ports are occupied.
    """
    for candidate in range(preferred_port, preferred_port + 50):
        probe = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        with probe:
            probe.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
            try:
                probe.bind((host, candidate))
            except OSError:
                # Port busy; try the next one.
                continue
            return candidate
    raise RuntimeError(f"No open port found from {preferred_port} to {preferred_port + 49}.")
def language_name(code: str) -> str:
    """Map a language code to its display name; unknown codes pass through unchanged."""
    try:
        return LANGUAGES[code]
    except KeyError:
        return code
def build_prompt(text: str, source_language: str, target_language: str) -> str:
    """Compose the instruction prompt sent to the translation model.

    Uses a Chinese instruction when either side of the pair is Chinese and an
    English one otherwise; both ask for the bare translation with no commentary.
    """
    target = language_name(target_language)
    stripped = text.strip()
    if "zh" in (source_language, target_language):
        return f"将以下文本翻译为{target},注意只需要输出翻译后的结果,不要额外解释:\n\n{stripped}"
    return f"Translate the following segment into {target}, without additional explanation.\n\n{stripped}"
@lru_cache(maxsize=1)
def get_model():
    """Load and memoize the ``(tokenizer, model)`` pair for MODEL_ID.

    Cached with ``lru_cache`` so the heavyweight load happens at most once per
    process; the health endpoint inspects this cache to report load status.

    Raises:
        RuntimeError: when torch/transformers are not installed.
    """
    try:
        import torch
        from transformers import AutoModelForCausalLM, AutoTokenizer
    except ImportError as exc:
        raise RuntimeError(
            "Transformers dependencies are not installed. Run `pip install -r requirements.txt`."
        ) from exc

    # bfloat16 on CUDA/MPS accelerators, float32 on plain CPU.
    if torch.backends.mps.is_available() or torch.cuda.is_available():
        default_dtype = torch.bfloat16
    else:
        default_dtype = torch.float32
    # A non-empty TORCH_DTYPE env var (e.g. "float16") overrides the default.
    requested_dtype = os.getenv("TORCH_DTYPE") or default_dtype

    tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, trust_remote_code=True)
    model = AutoModelForCausalLM.from_pretrained(
        MODEL_ID,
        torch_dtype=requested_dtype,
        device_map=os.getenv("DEVICE_MAP", "auto"),
        trust_remote_code=True,
        low_cpu_mem_usage=True,
    )
    model.eval()
    return tokenizer, model
def run_translation(text: str, source_language: str, target_language: str) -> str:
    """Run one chat-style generation and return the decoded translation.

    Builds the prompt, tokenizes it through the model's chat template,
    generates with the module-level sampling settings, and decodes only the
    newly generated tokens (everything after the prompt).
    """
    import torch
    tokenizer, model = get_model()
    prompt = build_prompt(text, source_language, target_language)
    messages = [{"role": "user", "content": prompt}]
    model_input = tokenizer.apply_chat_template(
        messages,
        tokenize=True,
        add_generation_prompt=True,
        return_tensors="pt",
    )
    # apply_chat_template can return either a dict-like BatchEncoding or a
    # bare input_ids tensor depending on the tokenizer; handle both shapes.
    if hasattr(model_input, "items"):
        model_input = {key: value.to(model.device) for key, value in model_input.items()}
        generation_inputs = model_input
        prompt_tokens = model_input["input_ids"].shape[-1]
    else:
        model_input = model_input.to(model.device)
        generation_inputs = {"input_ids": model_input}
        prompt_tokens = model_input.shape[-1]
    with torch.inference_mode():
        output = model.generate(
            **generation_inputs,
            max_new_tokens=DEFAULT_MAX_TOKENS,
            # Greedy decoding when TEMPERATURE is 0; sampling otherwise.
            do_sample=DEFAULT_TEMPERATURE > 0,
            temperature=DEFAULT_TEMPERATURE if DEFAULT_TEMPERATURE > 0 else None,
            top_k=int(os.getenv("TOP_K", "20")),
            top_p=float(os.getenv("TOP_P", "0.6")),
            repetition_penalty=float(os.getenv("REPETITION_PENALTY", "1.05")),
            eos_token_id=tokenizer.eos_token_id,
            # Fall back to EOS when the tokenizer defines no pad token.
            pad_token_id=tokenizer.pad_token_id or tokenizer.eos_token_id,
        )
    # Slice off the prompt tokens so only the model's continuation is decoded.
    generated = output[0, prompt_tokens:]
    return tokenizer.decode(generated, skip_special_tokens=True).strip()
@app.get("/")
def index():
    """Serve the single-page frontend."""
    index_file = STATIC_DIR / "index.html"
    return FileResponse(index_file)
@app.get("/api/config")
def config():
    """Expose frontend configuration: language list, model id, input limit."""
    payload = dict(
        languages=LANGUAGES,
        model_id=MODEL_ID,
        max_input_chars=MAX_INPUT_CHARS,
    )
    return payload
@app.get("/api/health")
def health():
    """Liveness probe; also reports whether the model has been loaded yet."""
    # The model is considered loaded once get_model's lru_cache holds an entry.
    model_loaded = get_model.cache_info().currsize > 0
    return {
        "ok": True,
        "model_loaded": model_loaded,
        "model_id": MODEL_ID,
    }
@app.get("/api/cache/stats")
def cache_stats():
    """Report cached-translation row counts, overall and for the active model."""
    conn = _get_db()
    with _db_lock:
        (total,) = conn.execute("SELECT COUNT(*) FROM translations").fetchone()
        (for_model,) = conn.execute(
            "SELECT COUNT(*) FROM translations WHERE model = ?", (MODEL_ID,)
        ).fetchone()
    return {"total": total, "for_current_model": for_model, "model_id": MODEL_ID}
@app.post("/api/translate", response_model=TranslateResponse)
def translate(payload: TranslateRequest):
    """Translate ``payload.text``, serving from the SQLite cache when possible.

    Raises a 400 for an "auto" target language and surfaces any model failure
    as a 500 carrying the exception text.
    """
    if payload.target_language == "auto":
        raise HTTPException(status_code=400, detail="Choose a target language.")

    # Serve a prior identical request straight from the cache.
    hit = cache_lookup(payload.text, payload.source_language, payload.target_language)
    if hit:
        return TranslateResponse(
            translation=hit["translation"],
            transliteration=hit["transliteration"],
            transliteration_label=hit["transliteration_label"],
            elapsed_ms=hit["elapsed_ms"],
            model=MODEL_ID,
            cached=True,
        )

    start = time.perf_counter()
    try:
        result = run_translation(
            payload.text,
            payload.source_language,
            payload.target_language,
        )
    except Exception as exc:
        # Surface model/runtime failures as an HTTP 500 with the error text.
        raise HTTPException(status_code=500, detail=str(exc)) from exc
    elapsed_ms = round((time.perf_counter() - start) * 1000)

    romanized, romanized_label = transliterate(result, payload.target_language)
    cache_store(
        payload.text,
        payload.source_language,
        payload.target_language,
        result,
        romanized,
        romanized_label,
        elapsed_ms,
    )
    return TranslateResponse(
        translation=result,
        transliteration=romanized,
        transliteration_label=romanized_label,
        elapsed_ms=elapsed_ms,
        model=MODEL_ID,
        cached=False,
    )
if __name__ == "__main__":
    import uvicorn

    # Bind host/port from the environment; walk forward from PORT (default
    # 7860) until a free port is found so local restarts don't collide.
    host = os.getenv("HOST", "0.0.0.0")
    port = choose_port(host, int(os.getenv("PORT", "7860")))
    print(f"AngelSlim Translate running at http://{host}:{port}")
    uvicorn.run(
        "app:app",
        host=host,
        port=port,
        # RELOAD=1 enables uvicorn's auto-reload for development.
        reload=os.getenv("RELOAD", "0") == "1",
    )