import json
import threading
import time
import traceback

import gradio as gr
import numpy as np
import spaces
import torch
from PIL import Image
from transformers import AutoModelForCausalLM, PreTrainedTokenizerFast

MODEL_ID = "btrkeks/transcoda-59M-zeroshot-v1"
# Exact Hub commit to load (weights, tokenizer and trust_remote_code code).
MODEL_REVISION = "b529f8aa5d996d9224df3395b5b92d0867343c91"

# Fixed canvas every input page is resized / cropped / padded to.
TARGET_W = 1050
TARGET_H = 1485

# Fallbacks used when a UI number field is empty (None) or zero; they are
# also the initial values shown in the UI.
DEFAULT_MAX_LENGTH = 2048
DEFAULT_NUM_BEAMS = 3
DEFAULT_REPETITION_PENALTY = 1.1

# Lazily loaded, process-wide model state; writes happen under _load_lock.
_load_lock = threading.Lock()
_model = None
_tokenizer = None
_load_error = None  # formatted traceback of a failed load, if any


def _log(message: str) -> None:
    """Print a tagged line to stdout and flush it immediately."""
    print(f"[transcoda-space] {message}", flush=True)


def get_model_and_tokenizer():
    """Return the shared (model, tokenizer) pair, loading it on first use.

    The model is loaded on CPU in eval mode; callers move it to another
    device as needed. A failed load is recorded in ``_load_error`` and
    re-raised on every later call (there is no retry within this process).

    Raises:
        RuntimeError: if loading failed, now or on an earlier call. The
            message is the formatted traceback of the original failure.
    """
    global _load_error, _model, _tokenizer

    # Fast path: skip the lock once loading has finished (or failed).
    if _model is not None and _tokenizer is not None:
        return _model, _tokenizer
    if _load_error is not None:
        raise RuntimeError(_load_error)

    with _load_lock:
        # Re-check under the lock: another thread may have completed (or
        # failed) the load while this one was waiting.
        if _model is not None and _tokenizer is not None:
            return _model, _tokenizer
        if _load_error is not None:
            raise RuntimeError(_load_error)

        started = time.perf_counter()
        try:
            _log(f"Loading {MODEL_ID}@{MODEL_REVISION} on CPU")
            model = AutoModelForCausalLM.from_pretrained(
                MODEL_ID,
                revision=MODEL_REVISION,
                trust_remote_code=True,
                low_cpu_mem_usage=False,
            )
            model.eval()
            tokenizer = PreTrainedTokenizerFast.from_pretrained(
                MODEL_ID,
                revision=MODEL_REVISION,
            )
            # Publish only after both pieces loaded successfully.
            _model = model
            _tokenizer = tokenizer
            _log(f"Model loaded in {time.perf_counter() - started:.1f}s")
            return _model, _tokenizer
        except Exception:
            _load_error = traceback.format_exc()
            _log("Model load failed:\n" + _load_error)
            raise RuntimeError(_load_error)


def preload_model() -> None:
    """Load the model in the background so the first request is faster.

    Best-effort by design: any failure has already been logged and stored in
    ``_load_error`` by get_model_and_tokenizer(), and transcribe() reports it
    to the user, so the exception is intentionally dropped here.
    """
    try:
        get_model_and_tokenizer()
    except Exception:
        pass


def preprocess_pil_image(image: Image.Image) -> torch.Tensor:
    """Convert a page image into the model's normalized pixel tensor.

    The image is converted to RGB and resized to TARGET_W pixels wide,
    keeping its aspect ratio. It is then cropped at the bottom, or padded
    with white rows, to exactly TARGET_H rows. Pixel values are mapped from
    [0, 255] to [-1, 1].

    Returns:
        float32 tensor of shape (1, 3, TARGET_H, TARGET_W).
    """
    img = image.convert("RGB")
    new_h = max(1, int(img.height * (TARGET_W / img.width)))
    img = img.resize((TARGET_W, new_h), Image.BILINEAR)
    arr = np.array(img)
    if arr.shape[0] > TARGET_H:
        arr = arr[:TARGET_H]
    elif arr.shape[0] < TARGET_H:
        # 255 = white, i.e. blank paper below the content.
        pad = np.full((TARGET_H - arr.shape[0], TARGET_W, 3), 255, dtype=arr.dtype)
        arr = np.concatenate([arr, pad], axis=0)
    tensor = torch.from_numpy(arr).permute(2, 0, 1).float() / 255.0
    tensor = (tensor - 0.5) / 0.5
    return tensor.unsqueeze(0)


def _select_device() -> str:
    """Return "cuda" when CUDA is available, otherwise "cpu"."""
    # A GPU is expected inside the @spaces.GPU call; the CPU fallback keeps
    # the app usable on machines without CUDA instead of crashing.
    return "cuda" if torch.cuda.is_available() else "cpu"


def _generation_params(decoding, max_length, num_beams, repetition_penalty):
    """Normalize and validate UI inputs for generate().

    Empty (None) or zero fields fall back to the DEFAULT_* constants.
    "greedy" decoding always uses a single beam.

    Returns:
        (max_length, num_beams, repetition_penalty) as (int, int, float).

    Raises:
        gr.Error: if a value is out of range after normalization.
    """
    beams = 1 if decoding == "greedy" else int(num_beams or DEFAULT_NUM_BEAMS)
    length = int(max_length or DEFAULT_MAX_LENGTH)
    penalty = float(repetition_penalty or DEFAULT_REPETITION_PENALTY)
    if length < 1:
        raise gr.Error("Max length must be a positive integer.")
    if beams < 1:
        raise gr.Error("Beam count must be a positive integer.")
    if penalty <= 0:
        raise gr.Error("Repetition penalty must be greater than 0.")
    return length, beams, penalty


@spaces.GPU
def transcribe(image, decoding, max_length, num_beams, repetition_penalty):
    """Transcribe one score-page image into **kern text.

    Args:
        image: uploaded page as a PIL image or numpy array; None if empty.
        decoding: "greedy" or "beam"; greedy forces a single beam.
        max_length: generate() max_length; None/0 -> DEFAULT_MAX_LENGTH.
        num_beams: beam count for "beam"; None/0 -> DEFAULT_NUM_BEAMS.
        repetition_penalty: generate() penalty;
            None/0 -> DEFAULT_REPETITION_PENALTY.

    Returns:
        (kern_text, metadata_json) for the two output components.

    Raises:
        gr.Error: on a missing image, invalid parameters, or a failed model
            load.
    """
    if image is None:
        raise gr.Error("Upload a score page image.")
    # Validate before loading/moving the model so bad input fails fast.
    length, beams, penalty = _generation_params(
        decoding, max_length, num_beams, repetition_penalty
    )

    device = _select_device()
    started = time.perf_counter()
    try:
        model, tokenizer = get_model_and_tokenizer()
    except Exception as exc:
        raise gr.Error(
            f"Transcoda failed to load. Check container logs.\n\n{exc}"
        ) from exc

    try:
        # Inside the try so that a failed or partial transfer is still undone
        # by the finally clause below.
        model.to(device)
        pil_image = image if isinstance(image, Image.Image) else Image.fromarray(image)
        pixel_values = preprocess_pil_image(pil_image).to(device)
        image_sizes = torch.tensor([[TARGET_H, TARGET_W]], device=device)
        # Generation is seeded with a single BOS token.
        input_ids = torch.full(
            (1, 1),
            int(model.config.bos_token_id),
            dtype=torch.long,
            device=device,
        )
        with torch.no_grad():
            output = model.generate(
                input_ids=input_ids,
                pixel_values=pixel_values,
                image_sizes=image_sizes,
                max_length=length,
                do_sample=False,
                num_beams=beams,
                repetition_penalty=penalty,
            )
        kern = tokenizer.decode(output[0], skip_special_tokens=True)
    finally:
        # The model is shared across requests: return it to CPU and release
        # cached CUDA memory whether or not generation succeeded.
        model.to("cpu")
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

    metadata = {
        "model": MODEL_ID,
        "revision": MODEL_REVISION,
        "device": device,
        "decoding": decoding,
        "num_beams": beams,
        "max_length": length,
        "repetition_penalty": penalty,
        "elapsed_ms": int((time.perf_counter() - started) * 1000),
        "output_chars": len(kern),
    }
    return kern, json.dumps(metadata, indent=2)


# Start loading weights immediately so the first request does not pay for it.
threading.Thread(target=preload_model, daemon=True).start()

with gr.Blocks() as demo:
    gr.Markdown("# Transcoda OMR")
    with gr.Row():
        image = gr.Image(type="pil", label="Score page image")
        with gr.Column():
            decoding = gr.Radio(["greedy", "beam"], value="greedy", label="Decoding")
            max_length = gr.Number(
                value=DEFAULT_MAX_LENGTH, precision=0, label="Max length"
            )
            num_beams = gr.Number(
                value=DEFAULT_NUM_BEAMS, precision=0, label="Beam count"
            )
            repetition_penalty = gr.Number(
                value=DEFAULT_REPETITION_PENALTY, label="Repetition penalty"
            )
            run = gr.Button("Transcribe")
    kern = gr.Textbox(label="Generated **kern", lines=24)
    metadata = gr.Code(label="Metadata", language="json")
    run.click(
        transcribe,
        inputs=[image, decoding, max_length, num_beams, repetition_penalty],
        outputs=[kern, metadata],
        api_name="transcribe",
    )

demo.queue().launch()