# transcoda / app.py
# Jamie Hlusko
# Add ZeroGPU-compatible app.py and requirements
# 0a8e15e
import json
import threading
import time
import traceback
import gradio as gr
import numpy as np
import spaces
import torch
from PIL import Image
from transformers import AutoModelForCausalLM, PreTrainedTokenizerFast
# Hugging Face Hub checkpoint, pinned to an exact commit so every restart
# loads the same weights and code.
MODEL_ID = "btrkeks/transcoda-59M-zeroshot-v1"
MODEL_REVISION = "b529f8aa5d996d9224df3395b5b92d0867343c91"

# Fixed input canvas: pages are scaled to TARGET_W pixels wide, then cropped or
# padded to TARGET_H pixels tall (see preprocess_pil_image).
TARGET_W, TARGET_H = 1050, 1485

# Process-wide lazily loaded model state; writes happen under _load_lock.
_load_lock = threading.Lock()
_model = _tokenizer = _load_error = None
def _log(message: str) -> None:
    """Print *message* to stdout with the Space's log prefix.

    Output is flushed immediately so lines show up in the container logs
    without waiting for stdout buffering.
    """
    tag = "[transcoda-space]"
    print(f"{tag} {message}", flush=True)
def get_model_and_tokenizer():
    """Return the shared ``(model, tokenizer)`` pair, loading it on first use.

    Both objects are fetched from ``MODEL_ID`` at ``MODEL_REVISION``, kept on
    CPU, and cached in module globals. Callers that find neither a cached pair
    nor a recorded failure serialise on ``_load_lock`` and re-check after
    acquiring it, so the load runs at most once. If loading fails, the
    formatted traceback is stored in ``_load_error`` and every later call
    re-raises it without retrying.

    Returns:
        ``(model, tokenizer)``, with the model already switched to eval mode.

    Raises:
        RuntimeError: Carrying the load traceback, if loading failed on this
            call or on an earlier one.
    """
    global _load_error, _model, _tokenizer

    def ready():
        return _model is not None and _tokenizer is not None

    # Unlocked fast path for every call after the first successful load.
    if ready():
        return _model, _tokenizer
    if _load_error is not None:
        raise RuntimeError(_load_error)

    with _load_lock:
        # Another thread may have finished (or failed) while we were waiting.
        if ready():
            return _model, _tokenizer
        if _load_error is not None:
            raise RuntimeError(_load_error)

        t0 = time.time()
        try:
            _log(f"Loading {MODEL_ID}@{MODEL_REVISION} on CPU")
            loaded_model = AutoModelForCausalLM.from_pretrained(
                MODEL_ID,
                revision=MODEL_REVISION,
                # Lets the Hub repo's own modelling code run (presumably
                # required by this checkpoint).
                trust_remote_code=True,
                low_cpu_mem_usage=False,
            )
            loaded_model.eval()
            loaded_tokenizer = PreTrainedTokenizerFast.from_pretrained(
                MODEL_ID,
                revision=MODEL_REVISION,
            )
            # Publish only after both pieces loaded successfully.
            _model, _tokenizer = loaded_model, loaded_tokenizer
            _log(f"Model loaded in {time.time() - t0:.1f}s")
            return _model, _tokenizer
        except Exception:
            _load_error = traceback.format_exc()
            _log("Model load failed:\n" + _load_error)
            raise RuntimeError(_load_error)
def preload_model() -> None:
    """Warm the model cache ahead of the first request.

    Failures are deliberately ignored here. ``get_model_and_tokenizer`` has
    already logged them and stored them in ``_load_error``, so the next real
    request reports the problem instead.
    """
    try:
        get_model_and_tokenizer()
    except Exception:
        # Best effort only: see docstring.
        return None
def preprocess_pil_image(image: Image.Image) -> torch.Tensor:
    """Turn a page image into the model's normalised pixel tensor.

    Steps:

    1. Convert the image to RGB.
    2. Scale it bilinearly to ``TARGET_W`` pixels wide, keeping the aspect
       ratio; the height is truncated to an int, minimum 1.
    3. Fit the height to ``TARGET_H``: extra rows at the bottom are dropped,
       and a short page is padded at the bottom with white.
    4. Map pixel values from ``[0, 255]`` to ``[-1, 1]``.

    Returns:
        A float32 tensor of shape ``(1, 3, TARGET_H, TARGET_W)``.
    """
    rgb = image.convert("RGB")
    scale = TARGET_W / rgb.width
    scaled_h = max(1, int(rgb.height * scale))
    rgb = rgb.resize((TARGET_W, scaled_h), Image.BILINEAR)
    pixels = np.array(rgb)

    # Start from an all-white page and copy in as many rows as fit; this both
    # crops tall pages and pads short ones.
    canvas = np.full((TARGET_H, TARGET_W, 3), 255, dtype=pixels.dtype)
    rows = min(pixels.shape[0], TARGET_H)
    canvas[:rows] = pixels[:rows]

    # HWC uint8 -> CHW float in [-1, 1], then add a batch dimension.
    chw = torch.from_numpy(canvas).permute(2, 0, 1).float()
    normalised = (chw / 255.0 - 0.5) / 0.5
    return normalised[None]
@spaces.GPU
def transcribe(image, decoding, max_length, num_beams, repetition_penalty):
    """Transcribe a score page image to **kern text on the GPU.

    Runs inside a ZeroGPU allocation (``@spaces.GPU``). The shared CPU model is
    moved to CUDA for the duration of the call and is always moved back
    afterwards, even if preprocessing or generation fails.

    Args:
        image: The uploaded page, as a PIL image or an array accepted by
            ``Image.fromarray``.
        decoding: ``"greedy"`` for single-beam decoding. Any other value uses
            beam search with ``num_beams`` beams.
        max_length: Maximum generated sequence length. Empty or zero falls back
            to 2048.
        num_beams: Beam count for beam search. Empty or zero falls back to 3.
        repetition_penalty: Passed to ``generate``. Empty or zero falls back to
            1.1.

    Returns:
        A ``(kern, metadata_json)`` tuple: the decoded text, and a JSON string
        describing the settings and timing of the run.

    Raises:
        gr.Error: If no image was supplied, a numeric option is out of range, or
            the model failed to load.
    """
    if image is None:
        raise gr.Error("Upload a score page image.")

    # Resolve the UI values once, so generate() and the reported metadata use
    # exactly the same settings. Empty/zero fields fall back to the UI defaults.
    max_len = int(max_length or 2048)
    beams = 1 if decoding == "greedy" else int(num_beams or 3)
    penalty = float(repetition_penalty or 1.1)
    # Reject values generate() cannot use. This gives the user a readable
    # message instead of a failure later inside generate().
    if max_len < 1:
        raise gr.Error("Max length must be at least 1.")
    if beams < 1:
        raise gr.Error("Beam count must be at least 1.")
    if penalty <= 0:
        raise gr.Error("Repetition penalty must be greater than 0.")

    device = "cuda"
    started = time.time()
    try:
        model, tokenizer = get_model_and_tokenizer()
    except Exception as exc:
        raise gr.Error(f"Transcoda failed to load. Check container logs.\n\n{exc}") from exc

    try:
        # The transfer is inside the try so that a failed or partial move to
        # the GPU is still undone by the finally block below.
        model.to(device)
        pil_image = image if isinstance(image, Image.Image) else Image.fromarray(image)
        pixel_values = preprocess_pil_image(pil_image).to(device)
        image_sizes = torch.tensor([[TARGET_H, TARGET_W]], device=device)
        # Generation is seeded with a single BOS token.
        input_ids = torch.full(
            (1, 1),
            int(model.config.bos_token_id),
            dtype=torch.long,
            device=device,
        )
        with torch.no_grad():
            output = model.generate(
                input_ids=input_ids,
                pixel_values=pixel_values,
                image_sizes=image_sizes,
                max_length=max_len,
                do_sample=False,
                num_beams=beams,
                repetition_penalty=penalty,
            )
        kern = tokenizer.decode(output[0], skip_special_tokens=True)
    finally:
        # Return the shared model to CPU so GPU memory is not held between calls.
        model.to("cpu")
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

    metadata = {
        "model": MODEL_ID,
        "revision": MODEL_REVISION,
        "device": device,
        "decoding": decoding,
        "num_beams": beams,
        "max_length": max_len,
        "repetition_penalty": penalty,
        "elapsed_ms": int((time.time() - started) * 1000),
        "output_chars": len(kern),
    }
    return kern, json.dumps(metadata, indent=2)
# Start the model load in a daemon thread at import time, so it overlaps with
# UI start-up. preload_model() ignores failures; get_model_and_tokenizer()
# records them, and transcribe() reports them to the user on the first request.
threading.Thread(target=preload_model, daemon=True).start()

# UI: a page image and decoding controls as inputs; the generated **kern text
# and run metadata (JSON) as outputs.
with gr.Blocks() as demo:
    gr.Markdown("# Transcoda OMR")
    with gr.Row():
        # type="pil" delivers the upload to transcribe() as a PIL image.
        image = gr.Image(type="pil", label="Score page image")
        with gr.Column():
            # Defaults mirror the fallbacks used inside transcribe().
            decoding = gr.Radio(["greedy", "beam"], value="greedy", label="Decoding")
            max_length = gr.Number(value=2048, precision=0, label="Max length")
            # Only used when decoding is "beam".
            num_beams = gr.Number(value=3, precision=0, label="Beam count")
            repetition_penalty = gr.Number(value=1.1, label="Repetition penalty")
            run = gr.Button("Transcribe")
    kern = gr.Textbox(label="Generated **kern", lines=24)
    metadata = gr.Code(label="Metadata", language="json")
    # api_name also exposes this handler to API clients as "transcribe".
    run.click(
        transcribe,
        inputs=[image, decoding, max_length, num_beams, repetition_penalty],
        outputs=[kern, metadata],
        api_name="transcribe",
    )

# Enable Gradio's request queue and start the server.
demo.queue().launch()