"""AngelSlim Translate.

A small FastAPI application that serves a locally loaded translation model
(via `transformers`), caches every translation in a SQLite database keyed by
(model, source language, target language, text hash), and adds best-effort
transliteration (pinyin / romaji / generic romanization) for the target text.

Configuration is entirely environment-driven (CACHE_DB, MODEL_ID,
MAX_INPUT_CHARS, MAX_TOKENS, TEMPERATURE, TOP_K, TOP_P, REPETITION_PENALTY,
TORCH_DTYPE, DEVICE_MAP, HOST, PORT, RELOAD).
"""

import hashlib
import os
import socket
import sqlite3
import threading
import time
from functools import lru_cache
from pathlib import Path
from typing import Optional

from fastapi import FastAPI, HTTPException
from fastapi.responses import FileResponse
from fastapi.staticfiles import StaticFiles
from pydantic import BaseModel, Field

APP_DIR = Path(__file__).resolve().parent
STATIC_DIR = APP_DIR / "static"
# Translation cache location; overridable so deployments can point at a volume.
CACHE_DB_PATH = Path(os.getenv("CACHE_DB", str(APP_DIR / "translations.db")))

MODEL_ID = os.getenv("MODEL_ID", "AngelSlim/Hy-MT1.5-1.8B-1.25bit")
MAX_INPUT_CHARS = int(os.getenv("MAX_INPUT_CHARS", "6000"))
DEFAULT_MAX_TOKENS = int(os.getenv("MAX_TOKENS", "256"))
DEFAULT_TEMPERATURE = float(os.getenv("TEMPERATURE", "0.7"))

# Language code -> human-readable name. "auto" is a sentinel meaning
# "detect the source language"; it is rejected as a *target* language.
LANGUAGES = {
    "auto": "Auto Detect",
    "en": "English",
    "zh": "Chinese",
    "ja": "Japanese",
    "ko": "Korean",
    "fr": "French",
    "de": "German",
    "es": "Spanish",
    "it": "Italian",
    "pt": "Portuguese",
    "ru": "Russian",
    "ar": "Arabic",
    "hi": "Hindi",
    "vi": "Vietnamese",
    "th": "Thai",
    "id": "Indonesian",
    "ms": "Malay",
    "tr": "Turkish",
    "nl": "Dutch",
    "pl": "Polish",
    "uk": "Ukrainian",
    "cs": "Czech",
    "sv": "Swedish",
    "fi": "Finnish",
    "el": "Greek",
    "he": "Hebrew",
}


class TranslateRequest(BaseModel):
    """Request body for POST /api/translate."""

    text: str = Field(..., min_length=1, max_length=MAX_INPUT_CHARS)
    source_language: str = "auto"
    target_language: str = "en"


class TranslateResponse(BaseModel):
    """Response body for POST /api/translate."""

    translation: str
    transliteration: Optional[str] = None
    transliteration_label: Optional[str] = None
    elapsed_ms: int
    model: str
    cached: bool = False


# Single shared connection (check_same_thread=False) serialized by _db_lock,
# since FastAPI may dispatch these sync endpoints on multiple worker threads.
_db_lock = threading.Lock()
_db_conn: Optional[sqlite3.Connection] = None


def _get_db() -> sqlite3.Connection:
    """Return the shared SQLite connection, creating it and the schema on
    first use.

    Creation is guarded by ``_db_lock`` (double-checked) so two concurrent
    first requests cannot race and each open their own connection.
    """
    global _db_conn
    if _db_conn is None:
        with _db_lock:
            if _db_conn is None:
                CACHE_DB_PATH.parent.mkdir(parents=True, exist_ok=True)
                conn = sqlite3.connect(str(CACHE_DB_PATH), check_same_thread=False)
                # WAL + NORMAL: concurrent readers and fewer fsyncs for a cache.
                conn.execute("PRAGMA journal_mode=WAL")
                conn.execute("PRAGMA synchronous=NORMAL")
                conn.execute(
                    """
                    CREATE TABLE IF NOT EXISTS translations (
                        model TEXT NOT NULL,
                        source_language TEXT NOT NULL,
                        target_language TEXT NOT NULL,
                        text_hash TEXT NOT NULL,
                        source_text TEXT NOT NULL,
                        translation TEXT NOT NULL,
                        transliteration TEXT,
                        transliteration_label TEXT,
                        elapsed_ms INTEGER NOT NULL,
                        created_at REAL NOT NULL,
                        PRIMARY KEY (model, source_language, target_language, text_hash)
                    )
                    """
                )
                conn.commit()
                _db_conn = conn
    return _db_conn


def _hash_text(text: str) -> str:
    """Stable cache key component: SHA-256 hex digest of the UTF-8 text."""
    return hashlib.sha256(text.encode("utf-8")).hexdigest()


def cache_lookup(text: str, source_language: str, target_language: str) -> Optional[dict]:
    """Return the cached translation for (model, source, target, text), or None."""
    conn = _get_db()
    with _db_lock:
        row = conn.execute(
            """
            SELECT translation, transliteration, transliteration_label, elapsed_ms
            FROM translations
            WHERE model = ? AND source_language = ? AND target_language = ? AND text_hash = ?
            """,
            (MODEL_ID, source_language, target_language, _hash_text(text)),
        ).fetchone()
    if not row:
        return None
    return {
        "translation": row[0],
        "transliteration": row[1],
        "transliteration_label": row[2],
        "elapsed_ms": row[3],
    }


def cache_store(
    text: str,
    source_language: str,
    target_language: str,
    translation: str,
    transliteration: Optional[str],
    transliteration_label: Optional[str],
    elapsed_ms: int,
) -> None:
    """Upsert a translation result into the cache (last write wins)."""
    conn = _get_db()
    with _db_lock:
        conn.execute(
            """
            INSERT OR REPLACE INTO translations
            (model, source_language, target_language, text_hash, source_text,
             translation, transliteration, transliteration_label, elapsed_ms, created_at)
            VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
            """,
            (
                MODEL_ID,
                source_language,
                target_language,
                _hash_text(text),
                text,
                translation,
                transliteration,
                transliteration_label,
                elapsed_ms,
                time.time(),
            ),
        )
        conn.commit()


def transliterate(text: str, language: str) -> tuple[Optional[str], Optional[str]]:
    """Best-effort transliteration of `text` for the given language code.

    Returns a (transliteration, label) pair, or (None, None) when the text is
    empty, the language has no transliteration, the optional dependency is not
    installed, or transliteration fails. Failures are deliberately swallowed:
    transliteration is a cosmetic extra and must never break a translation.
    """
    if not text:
        return None, None
    try:
        if language == "zh":
            from pypinyin import Style, pinyin

            tokens = pinyin(text, style=Style.TONE, errors="default")
            joined = " ".join("".join(parts) for parts in tokens if parts)
            return (joined or None), "Pinyin"
        if language == "ja":
            import pykakasi

            kks = pykakasi.kakasi()
            result = kks.convert(text)
            joined = " ".join(
                item["hepburn"] for item in result if item.get("hepburn")
            ).strip()
            return (joined or None), "Romaji"
        if language in {"ko", "ru", "uk", "el", "he", "ar", "hi", "th"}:
            from unidecode import unidecode

            romanized = unidecode(text).strip()
            # If unidecode produced nothing, or changed nothing (already
            # ASCII), a transliteration adds no information — skip it.
            if not romanized or romanized == text:
                return None, None
            return romanized, "Romanization"
    except ImportError:
        # Optional dependency not installed — feature silently unavailable.
        return None, None
    except Exception:
        # Any transliteration failure degrades gracefully to "no extra info".
        return None, None
    return None, None


app = FastAPI(title="AngelSlim Translate")
app.mount("/static", StaticFiles(directory=STATIC_DIR), name="static")


def choose_port(host: str, preferred_port: int) -> int:
    """Return the first bindable port in [preferred_port, preferred_port+49].

    Probes by binding a throwaway socket; note the port is released again
    before uvicorn binds it, so a (rare) race with another process remains.

    Raises:
        RuntimeError: if none of the 50 candidate ports can be bound.
    """
    for port in range(preferred_port, preferred_port + 50):
        with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock:
            sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
            try:
                sock.bind((host, port))
            except OSError:
                continue
            return port
    raise RuntimeError(
        f"No open port found from {preferred_port} to {preferred_port + 49}."
    )


def language_name(code: str) -> str:
    """Map a language code to its display name, falling back to the code."""
    return LANGUAGES.get(code, code)


def build_prompt(text: str, source_language: str, target_language: str) -> str:
    """Build the chat prompt for the model.

    A Chinese instruction is used whenever Chinese is on either side of the
    translation; otherwise an English instruction is used.
    """
    target = language_name(target_language)
    source_text = text.strip()
    if source_language == "zh" or target_language == "zh":
        # NOTE: in the reviewed source this `return` was separated from its
        # f-string (returning None); reconstructed as one statement.
        return f"将以下文本翻译为{target},注意只需要输出翻译后的结果,不要额外解释:\n\n{source_text}"
    return (
        f"Translate the following segment into {target}, "
        f"without additional explanation.\n\n{source_text}"
    )


@lru_cache(maxsize=1)
def get_model():
    """Load and memoize (tokenizer, model). Heavy: runs once per process.

    Raises:
        RuntimeError: if the transformers stack is not installed.
    """
    try:
        import torch
        from transformers import AutoModelForCausalLM, AutoTokenizer
    except ImportError as exc:
        raise RuntimeError(
            "Transformers dependencies are not installed. Run `pip install -r requirements.txt`."
        ) from exc

    # bfloat16 on GPU/MPS, float32 on CPU; TORCH_DTYPE overrides when set.
    dtype = (
        torch.bfloat16
        if torch.backends.mps.is_available() or torch.cuda.is_available()
        else torch.float32
    )
    tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, trust_remote_code=True)
    model = AutoModelForCausalLM.from_pretrained(
        MODEL_ID,
        torch_dtype=os.getenv("TORCH_DTYPE", "auto") if os.getenv("TORCH_DTYPE") else dtype,
        device_map=os.getenv("DEVICE_MAP", "auto"),
        trust_remote_code=True,
        low_cpu_mem_usage=True,
    )
    model.eval()
    return tokenizer, model


def run_translation(text: str, source_language: str, target_language: str) -> str:
    """Run one translation through the model and return the decoded text."""
    import torch

    tokenizer, model = get_model()
    prompt = build_prompt(text, source_language, target_language)
    messages = [{"role": "user", "content": prompt}]
    model_input = tokenizer.apply_chat_template(
        messages,
        tokenize=True,
        add_generation_prompt=True,
        return_tensors="pt",
    )
    # apply_chat_template may return either a BatchEncoding (dict-like) or a
    # bare tensor depending on the tokenizer; handle both shapes.
    if hasattr(model_input, "items"):
        model_input = {key: value.to(model.device) for key, value in model_input.items()}
        generation_inputs = model_input
        prompt_tokens = model_input["input_ids"].shape[-1]
    else:
        model_input = model_input.to(model.device)
        generation_inputs = {"input_ids": model_input}
        prompt_tokens = model_input.shape[-1]

    with torch.inference_mode():
        output = model.generate(
            **generation_inputs,
            max_new_tokens=DEFAULT_MAX_TOKENS,
            # TEMPERATURE == 0 means greedy decoding (sampling disabled).
            do_sample=DEFAULT_TEMPERATURE > 0,
            temperature=DEFAULT_TEMPERATURE if DEFAULT_TEMPERATURE > 0 else None,
            top_k=int(os.getenv("TOP_K", "20")),
            top_p=float(os.getenv("TOP_P", "0.6")),
            repetition_penalty=float(os.getenv("REPETITION_PENALTY", "1.05")),
            eos_token_id=tokenizer.eos_token_id,
            pad_token_id=tokenizer.pad_token_id or tokenizer.eos_token_id,
        )
    # Strip the prompt tokens; decode only the newly generated continuation.
    generated = output[0, prompt_tokens:]
    return tokenizer.decode(generated, skip_special_tokens=True).strip()


@app.get("/")
def index():
    """Serve the single-page UI."""
    return FileResponse(STATIC_DIR / "index.html")


@app.get("/api/config")
def config():
    """Expose languages, model id, and input limit to the front end."""
    return {
        "languages": LANGUAGES,
        "model_id": MODEL_ID,
        "max_input_chars": MAX_INPUT_CHARS,
    }


@app.get("/api/health")
def health():
    """Liveness probe; reports whether the model has been loaded yet."""
    loaded = get_model.cache_info().currsize > 0
    return {
        "ok": True,
        "model_loaded": loaded,
        "model_id": MODEL_ID,
    }


@app.get("/api/cache/stats")
def cache_stats():
    """Return cache row counts, overall and for the current model."""
    conn = _get_db()
    with _db_lock:
        total = conn.execute("SELECT COUNT(*) FROM translations").fetchone()[0]
        for_model = conn.execute(
            "SELECT COUNT(*) FROM translations WHERE model = ?", (MODEL_ID,)
        ).fetchone()[0]
    return {"total": total, "for_current_model": for_model, "model_id": MODEL_ID}


@app.post("/api/translate", response_model=TranslateResponse)
def translate(payload: TranslateRequest):
    """Translate `payload.text`, serving from cache when possible.

    Raises:
        HTTPException 400: target language is "auto".
        HTTPException 500: model inference failed.
    """
    if payload.target_language == "auto":
        raise HTTPException(status_code=400, detail="Choose a target language.")

    cached = cache_lookup(payload.text, payload.source_language, payload.target_language)
    if cached:
        return TranslateResponse(
            translation=cached["translation"],
            transliteration=cached["transliteration"],
            transliteration_label=cached["transliteration_label"],
            elapsed_ms=cached["elapsed_ms"],
            model=MODEL_ID,
            cached=True,
        )

    started = time.perf_counter()
    try:
        translation = run_translation(
            payload.text,
            payload.source_language,
            payload.target_language,
        )
    except Exception as exc:
        raise HTTPException(status_code=500, detail=str(exc)) from exc
    elapsed_ms = round((time.perf_counter() - started) * 1000)

    transliteration, transliteration_label = transliterate(
        translation, payload.target_language
    )
    cache_store(
        payload.text,
        payload.source_language,
        payload.target_language,
        translation,
        transliteration,
        transliteration_label,
        elapsed_ms,
    )
    return TranslateResponse(
        translation=translation,
        transliteration=transliteration,
        transliteration_label=transliteration_label,
        elapsed_ms=elapsed_ms,
        model=MODEL_ID,
        cached=False,
    )


if __name__ == "__main__":
    import uvicorn

    host = os.getenv("HOST", "0.0.0.0")
    port = choose_port(host, int(os.getenv("PORT", "7860")))
    print(f"AngelSlim Translate running at http://{host}:{port}")
    uvicorn.run(
        "app:app",
        host=host,
        port=port,
        reload=os.getenv("RELOAD", "0") == "1",
    )