# AngelSlim Translate — FastAPI app serving a local translation model with a SQLite response cache.
| import hashlib | |
| import os | |
| import socket | |
| import sqlite3 | |
| import threading | |
| import time | |
| from functools import lru_cache | |
| from pathlib import Path | |
| from typing import Optional | |
| from fastapi import FastAPI, HTTPException | |
| from fastapi.responses import FileResponse | |
| from fastapi.staticfiles import StaticFiles | |
| from pydantic import BaseModel, Field | |
# Filesystem anchors: the app directory locates static assets and the default DB path.
APP_DIR = Path(__file__).resolve().parent
STATIC_DIR = APP_DIR / "static"
# SQLite cache location; override with the CACHE_DB env var.
CACHE_DB_PATH = Path(os.getenv("CACHE_DB", str(APP_DIR / "translations.db")))
# Hugging Face model id used for translation; overridable via MODEL_ID.
MODEL_ID = os.getenv("MODEL_ID", "AngelSlim/Hy-MT1.5-1.8B-1.25bit")
# Request/generation limits, all env-overridable.
MAX_INPUT_CHARS = int(os.getenv("MAX_INPUT_CHARS", "6000"))
DEFAULT_MAX_TOKENS = int(os.getenv("MAX_TOKENS", "256"))
DEFAULT_TEMPERATURE = float(os.getenv("TEMPERATURE", "0.7"))
# Supported language codes mapped to display names; "auto" means detect source.
LANGUAGES = {
    "auto": "Auto Detect",
    "en": "English",
    "zh": "Chinese",
    "ja": "Japanese",
    "ko": "Korean",
    "fr": "French",
    "de": "German",
    "es": "Spanish",
    "it": "Italian",
    "pt": "Portuguese",
    "ru": "Russian",
    "ar": "Arabic",
    "hi": "Hindi",
    "vi": "Vietnamese",
    "th": "Thai",
    "id": "Indonesian",
    "ms": "Malay",
    "tr": "Turkish",
    "nl": "Dutch",
    "pl": "Polish",
    "uk": "Ukrainian",
    "cs": "Czech",
    "sv": "Swedish",
    "fi": "Finnish",
    "el": "Greek",
    "he": "Hebrew",
}
class TranslateRequest(BaseModel):
    """Incoming translation request body."""

    # Text to translate; length is validated against the configured cap.
    text: str = Field(..., min_length=1, max_length=MAX_INPUT_CHARS)
    # Language codes from LANGUAGES; "auto" asks the model to detect the source.
    source_language: str = "auto"
    target_language: str = "en"
class TranslateResponse(BaseModel):
    """Translation result returned to the client."""

    translation: str
    # Optional romanized form of the translation plus its display label
    # (e.g. "Pinyin", "Romaji"); None when no transliteration applies.
    transliteration: Optional[str] = None
    transliteration_label: Optional[str] = None
    # Wall-clock generation time (or the cached entry's original time).
    elapsed_ms: int
    model: str
    # True when the response was served from the SQLite cache.
    cached: bool = False
# Serializes all SQLite access; the single shared connection is created with
# check_same_thread=False, so the lock is what makes cross-thread use safe.
_db_lock = threading.Lock()
# Lazily created shared connection (see _get_db).
_db_conn: Optional[sqlite3.Connection] = None
def _get_db() -> sqlite3.Connection:
    """Return the process-wide SQLite connection, creating it on first use.

    The connection is opened with check_same_thread=False and shared across
    request threads; callers must hold _db_lock around statement execution.
    WAL mode + NORMAL synchronous trade a little durability for concurrency.

    Fix: first-time initialization is now guarded by _db_lock (double-checked)
    so two concurrent first requests cannot each open a connection and race
    on the _db_conn global.
    """
    global _db_conn
    if _db_conn is None:
        with _db_lock:
            # Re-check under the lock: another thread may have won the race.
            if _db_conn is None:
                CACHE_DB_PATH.parent.mkdir(parents=True, exist_ok=True)
                conn = sqlite3.connect(str(CACHE_DB_PATH), check_same_thread=False)
                conn.execute("PRAGMA journal_mode=WAL")
                conn.execute("PRAGMA synchronous=NORMAL")
                conn.execute(
                    """
                    CREATE TABLE IF NOT EXISTS translations (
                        model TEXT NOT NULL,
                        source_language TEXT NOT NULL,
                        target_language TEXT NOT NULL,
                        text_hash TEXT NOT NULL,
                        source_text TEXT NOT NULL,
                        translation TEXT NOT NULL,
                        transliteration TEXT,
                        transliteration_label TEXT,
                        elapsed_ms INTEGER NOT NULL,
                        created_at REAL NOT NULL,
                        PRIMARY KEY (model, source_language, target_language, text_hash)
                    )
                    """
                )
                conn.commit()
                _db_conn = conn
    return _db_conn
| def _hash_text(text: str) -> str: | |
| return hashlib.sha256(text.encode("utf-8")).hexdigest() | |
def cache_lookup(text: str, source_language: str, target_language: str) -> Optional[dict]:
    """Fetch a cached translation for (model, source, target, text), or None."""
    conn = _get_db()
    key = (MODEL_ID, source_language, target_language, _hash_text(text))
    with _db_lock:
        record = conn.execute(
            """
            SELECT translation, transliteration, transliteration_label, elapsed_ms
            FROM translations
            WHERE model = ? AND source_language = ? AND target_language = ? AND text_hash = ?
            """,
            key,
        ).fetchone()
    if record is None:
        return None
    translation, transliteration, label, elapsed = record
    return {
        "translation": translation,
        "transliteration": transliteration,
        "transliteration_label": label,
        "elapsed_ms": elapsed,
    }
def cache_store(
    text: str,
    source_language: str,
    target_language: str,
    translation: str,
    transliteration: Optional[str],
    transliteration_label: Optional[str],
    elapsed_ms: int,
) -> None:
    """Insert or overwrite the cached translation row for this request key."""
    conn = _get_db()
    row = (
        MODEL_ID,
        source_language,
        target_language,
        _hash_text(text),
        text,
        translation,
        transliteration,
        transliteration_label,
        elapsed_ms,
        time.time(),
    )
    with _db_lock:
        conn.execute(
            """
            INSERT OR REPLACE INTO translations
            (model, source_language, target_language, text_hash, source_text,
            translation, transliteration, transliteration_label, elapsed_ms, created_at)
            VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
            """,
            row,
        )
        conn.commit()
def transliterate(text: str, language: str) -> tuple[Optional[str], Optional[str]]:
    """Best-effort romanization of *text* for the given language code.

    Returns (romanized_text, label) — e.g. ("ni hao", "Pinyin") — or
    (None, None) when the language is unsupported, the optional converter
    package is missing, or conversion produces nothing useful.
    """
    if not text:
        return None, None
    try:
        if language == "zh":
            from pypinyin import Style, pinyin

            syllables = pinyin(text, style=Style.TONE, errors="default")
            rendered = " ".join("".join(group) for group in syllables if group)
            return (rendered or None), "Pinyin"
        if language == "ja":
            import pykakasi

            converter = pykakasi.kakasi()
            pieces = converter.convert(text)
            rendered = " ".join(p["hepburn"] for p in pieces if p.get("hepburn")).strip()
            return (rendered or None), "Romaji"
        if language in {"ko", "ru", "uk", "el", "he", "ar", "hi", "th"}:
            from unidecode import unidecode

            rendered = unidecode(text).strip()
            # Skip when transliteration is empty or a no-op (already ASCII-ish).
            if not rendered or rendered == text:
                return None, None
            return rendered, "Romanization"
    except Exception:
        # Optional dependency missing or converter failure: degrade silently.
        return None, None
    return None, None
# FastAPI application; static assets (the single-page UI) are served from /static.
app = FastAPI(title="AngelSlim Translate")
app.mount("/static", StaticFiles(directory=STATIC_DIR), name="static")
def choose_port(host: str, preferred_port: int) -> int:
    """Return the first bindable TCP port in [preferred_port, preferred_port + 49].

    Probes by actually binding (with SO_REUSEADDR) and releasing the socket.
    Raises RuntimeError when all 50 candidates are taken.
    """
    for candidate in range(preferred_port, preferred_port + 50):
        probe = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        try:
            probe.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
            try:
                probe.bind((host, candidate))
            except OSError:
                continue
            return candidate
        finally:
            probe.close()
    raise RuntimeError(f"No open port found from {preferred_port} to {preferred_port + 49}.")
def language_name(code: str) -> str:
    """Human-readable name for a language code; unknown codes pass through unchanged."""
    try:
        return LANGUAGES[code]
    except KeyError:
        return code
def build_prompt(text: str, source_language: str, target_language: str) -> str:
    """Build the single-turn translation instruction given to the model.

    Uses a Chinese instruction when either side of the pair is Chinese,
    otherwise an English one; both ask for the bare translation only.
    """
    stripped = text.strip()
    target = language_name(target_language)
    if "zh" in (source_language, target_language):
        return f"将以下文本翻译为{target},注意只需要输出翻译后的结果,不要额外解释:\n\n{stripped}"
    return f"Translate the following segment into {target}, without additional explanation.\n\n{stripped}"
@lru_cache(maxsize=1)
def get_model():
    """Load and memoize the (tokenizer, model) pair for MODEL_ID.

    Fix: decorated with @lru_cache(maxsize=1) so the model is loaded exactly
    once per process. Without it every request reloaded the model, and
    health() — which calls get_model.cache_info() — crashed with
    AttributeError. lru_cache was already imported but unused, indicating
    the decorator had been dropped.

    Raises RuntimeError when torch/transformers are not installed.
    """
    try:
        import torch
        from transformers import AutoModelForCausalLM, AutoTokenizer
    except ImportError as exc:
        raise RuntimeError(
            "Transformers dependencies are not installed. Run `pip install -r requirements.txt`."
        ) from exc
    # Prefer bfloat16 on CUDA/MPS accelerators; fall back to float32 on CPU.
    dtype = torch.bfloat16 if torch.backends.mps.is_available() or torch.cuda.is_available() else torch.float32
    tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, trust_remote_code=True)
    # An explicit TORCH_DTYPE env var (e.g. "auto", "float16") overrides detection.
    dtype_override = os.getenv("TORCH_DTYPE")
    model = AutoModelForCausalLM.from_pretrained(
        MODEL_ID,
        torch_dtype=dtype_override if dtype_override else dtype,
        device_map=os.getenv("DEVICE_MAP", "auto"),
        trust_remote_code=True,
        low_cpu_mem_usage=True,
    )
    model.eval()
    return tokenizer, model
def run_translation(text: str, source_language: str, target_language: str) -> str:
    """Translate *text* with the local model and return only the generated text.

    Builds a one-turn chat prompt, tokenizes it with the model's chat
    template, generates with env-configurable sampling settings, strips the
    prompt tokens from the output, and decodes without special tokens.
    """
    import torch
    tokenizer, model = get_model()
    prompt = build_prompt(text, source_language, target_language)
    messages = [{"role": "user", "content": prompt}]
    model_input = tokenizer.apply_chat_template(
        messages,
        tokenize=True,
        add_generation_prompt=True,
        return_tensors="pt",
    )
    # apply_chat_template may return a dict-like encoding or a bare input_ids
    # tensor — presumably depending on tokenizer version; handle both shapes.
    if hasattr(model_input, "items"):
        model_input = {key: value.to(model.device) for key, value in model_input.items()}
        generation_inputs = model_input
        prompt_tokens = model_input["input_ids"].shape[-1]
    else:
        model_input = model_input.to(model.device)
        generation_inputs = {"input_ids": model_input}
        prompt_tokens = model_input.shape[-1]
    with torch.inference_mode():
        output = model.generate(
            **generation_inputs,
            max_new_tokens=DEFAULT_MAX_TOKENS,
            # Temperature 0 disables sampling (greedy decoding).
            do_sample=DEFAULT_TEMPERATURE > 0,
            temperature=DEFAULT_TEMPERATURE if DEFAULT_TEMPERATURE > 0 else None,
            top_k=int(os.getenv("TOP_K", "20")),
            top_p=float(os.getenv("TOP_P", "0.6")),
            repetition_penalty=float(os.getenv("REPETITION_PENALTY", "1.05")),
            eos_token_id=tokenizer.eos_token_id,
            # Fall back to EOS when the tokenizer defines no pad token.
            pad_token_id=tokenizer.pad_token_id or tokenizer.eos_token_id,
        )
    # Keep only the newly generated token ids (everything after the prompt).
    generated = output[0, prompt_tokens:]
    return tokenizer.decode(generated, skip_special_tokens=True).strip()
def index():
    """Serve the single-page UI.

    NOTE(review): no @app route decorator is visible in this chunk — confirm
    this handler is actually registered (e.g. @app.get("/")).
    """
    page = STATIC_DIR / "index.html"
    return FileResponse(page)
def config():
    """Expose UI bootstrap settings: language list, model id, input cap.

    NOTE(review): no @app route decorator is visible in this chunk — confirm
    this handler is actually registered with the FastAPI app.
    """
    settings = {
        "languages": LANGUAGES,
        "model_id": MODEL_ID,
        "max_input_chars": MAX_INPUT_CHARS,
    }
    return settings
def health():
    """Liveness probe reporting whether the model has been loaded yet.

    Fix: get_model only exposes cache_info() when it is memoized with
    lru_cache; as written in this file it is a plain function, so calling
    get_model.cache_info() raised AttributeError and the probe 500'd.
    Guard with getattr so the probe degrades to model_loaded=False instead.

    NOTE(review): no @app route decorator is visible in this chunk — confirm
    this handler is actually registered with the FastAPI app.
    """
    cache_info = getattr(get_model, "cache_info", None)
    loaded = bool(cache_info and cache_info().currsize > 0)
    return {
        "ok": True,
        "model_loaded": loaded,
        "model_id": MODEL_ID,
    }
def cache_stats():
    """Report cached-translation counts, overall and for the active model.

    NOTE(review): no @app route decorator is visible in this chunk — confirm
    this handler is actually registered with the FastAPI app.
    """
    conn = _get_db()
    with _db_lock:
        overall = conn.execute("SELECT COUNT(*) FROM translations").fetchone()
        current = conn.execute(
            "SELECT COUNT(*) FROM translations WHERE model = ?", (MODEL_ID,)
        ).fetchone()
    return {"total": overall[0], "for_current_model": current[0], "model_id": MODEL_ID}
def translate(payload: TranslateRequest):
    """Translate the payload text, serving from cache when possible.

    Rejects "auto" as a target with HTTP 400. On a cache miss, runs the
    model, transliterates the result, stores everything in the cache, and
    returns the response; model failures surface as HTTP 500.

    NOTE(review): no @app route decorator is visible in this chunk — confirm
    this handler is actually registered (e.g. @app.post(...)).
    """
    if payload.target_language == "auto":
        raise HTTPException(status_code=400, detail="Choose a target language.")
    hit = cache_lookup(payload.text, payload.source_language, payload.target_language)
    if hit:
        return TranslateResponse(
            translation=hit["translation"],
            transliteration=hit["transliteration"],
            transliteration_label=hit["transliteration_label"],
            elapsed_ms=hit["elapsed_ms"],
            model=MODEL_ID,
            cached=True,
        )
    t0 = time.perf_counter()
    try:
        result = run_translation(
            payload.text,
            payload.source_language,
            payload.target_language,
        )
    except Exception as exc:
        raise HTTPException(status_code=500, detail=str(exc)) from exc
    duration_ms = round((time.perf_counter() - t0) * 1000)
    romanized, romanized_label = transliterate(result, payload.target_language)
    cache_store(
        payload.text,
        payload.source_language,
        payload.target_language,
        result,
        romanized,
        romanized_label,
        duration_ms,
    )
    return TranslateResponse(
        translation=result,
        transliteration=romanized,
        transliteration_label=romanized_label,
        elapsed_ms=duration_ms,
        model=MODEL_ID,
        cached=False,
    )
if __name__ == "__main__":
    import uvicorn
    # Bind host/port are env-overridable; if PORT is busy, probe upward for
    # the next free port in a 50-port window.
    host = os.getenv("HOST", "0.0.0.0")
    port = choose_port(host, int(os.getenv("PORT", "7860")))
    print(f"AngelSlim Translate running at http://{host}:{port}")
    # Import-string form ("app:app") is required for uvicorn's reload mode.
    uvicorn.run(
        "app:app",
        host=host,
        port=port,
        reload=os.getenv("RELOAD", "0") == "1",
    )