File size: 5,312 Bytes
0a8e15e | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 | import json
import threading
import time
import traceback
import gradio as gr
import numpy as np
import spaces
import torch
from PIL import Image
from transformers import AutoModelForCausalLM, PreTrainedTokenizerFast
MODEL_ID = "btrkeks/transcoda-59M-zeroshot-v1"
MODEL_REVISION = "b529f8aa5d996d9224df3395b5b92d0867343c91"
TARGET_W = 1050
TARGET_H = 1485
_load_lock = threading.Lock()
_model = None
_tokenizer = None
_load_error = None
def _log(message: str) -> None:
print(f"[transcoda-space] {message}", flush=True)
def get_model_and_tokenizer():
    """Return the shared ``(model, tokenizer)`` pair, loading it on first use.

    After the first successful load the pair is kept in module globals. If a
    load attempt fails, its formatted traceback is kept instead, and every
    later call raises immediately with that text rather than retrying.

    Returns:
        tuple: ``(model, tokenizer)`` for ``MODEL_ID`` at ``MODEL_REVISION``.

    Raises:
        RuntimeError: when loading fails now or failed on an earlier call; the
            message is the traceback text of the original failure.
    """
    global _load_error, _model, _tokenizer

    def _from_cache():
        # Cached pair, or re-raise a cached failure; None means "not loaded yet".
        if _model is not None and _tokenizer is not None:
            return _model, _tokenizer
        if _load_error is not None:
            raise RuntimeError(_load_error)
        return None

    # Check without the lock first so the already-loaded path never waits on it.
    cached = _from_cache()
    if cached is not None:
        return cached

    with _load_lock:
        # Re-check under the lock: another thread may have loaded (or failed) meanwhile.
        cached = _from_cache()
        if cached is not None:
            return cached

        t0 = time.time()
        try:
            _log(f"Loading {MODEL_ID}@{MODEL_REVISION} on CPU")
            loaded_model = AutoModelForCausalLM.from_pretrained(
                MODEL_ID,
                revision=MODEL_REVISION,
                trust_remote_code=True,  # allows executing modeling code from the model repo
                low_cpu_mem_usage=False,
            )
            loaded_model.eval()
            loaded_tokenizer = PreTrainedTokenizerFast.from_pretrained(
                MODEL_ID,
                revision=MODEL_REVISION,
            )
            _model, _tokenizer = loaded_model, loaded_tokenizer
            _log(f"Model loaded in {time.time() - t0:.1f}s")
            return _model, _tokenizer
        except Exception:
            # Remember the failure so later callers fail fast with the same traceback.
            _load_error = traceback.format_exc()
            _log("Model load failed:\n" + _load_error)
            raise RuntimeError(_load_error)
def preload_model() -> None:
    """Best-effort warm-up of the model cache, intended for a background thread.

    Exceptions are deliberately swallowed: ``get_model_and_tokenizer`` logs a
    load failure and caches it. A later request therefore gets the error,
    rather than it crashing this thread.
    """
    try:
        get_model_and_tokenizer()
    except Exception:
        return
def preprocess_pil_image(image: Image.Image) -> torch.Tensor:
    """Turn a page image into the model's fixed-size, normalized input tensor.

    The page is converted to RGB and bilinearly scaled to ``TARGET_W`` pixels
    wide, keeping the aspect ratio. The height is truncated to an int, with a
    minimum of 1. The bottom is then cropped, or padded with white rows, to
    exactly ``TARGET_H`` rows. Pixel values are mapped from ``[0, 255]`` to
    ``[-1, 1]``.

    Returns:
        A float tensor of shape ``(1, 3, TARGET_H, TARGET_W)``.
    """
    rgb = image.convert("RGB")
    scaled_h = max(1, int(rgb.height * (TARGET_W / rgb.width)))
    pixels = np.array(rgb.resize((TARGET_W, scaled_h), Image.BILINEAR))

    # Drop rows past TARGET_H, then fill any shortfall with white (255) rows.
    pixels = pixels[:TARGET_H]
    missing_rows = TARGET_H - pixels.shape[0]
    if missing_rows:
        filler = np.full((missing_rows, TARGET_W, 3), 255, dtype=pixels.dtype)
        pixels = np.concatenate([pixels, filler], axis=0)

    # HWC uint8 -> CHW float in [0, 1] -> [-1, 1], plus a leading batch axis.
    chw = torch.from_numpy(pixels).permute(2, 0, 1).float() / 255.0
    return ((chw - 0.5) / 0.5).unsqueeze(0)
@spaces.GPU
def transcribe(image, decoding, max_length, num_beams, repetition_penalty):
    """Transcribe one score-page image to **kern text with the Transcoda model.

    Args:
        image: PIL image (what the ``gr.Image(type="pil")`` input delivers) or an
            array accepted by ``Image.fromarray``; ``None`` is rejected.
        decoding: ``"greedy"`` forces a single beam; any other value uses
            ``num_beams`` beams.
        max_length: generation length cap; empty/0 falls back to 2048.
        num_beams: beam count for beam decoding; empty/0 falls back to 3.
        repetition_penalty: passed to ``generate``; empty/0 falls back to 1.1.

    Returns:
        ``(kern, metadata_json)``: the decoded text and a JSON summary of the run.

    Raises:
        gr.Error: if no image is given, a numeric setting is out of range, or the
            model failed to load.
    """
    if image is None:
        raise gr.Error("Upload a score page image.")

    # Resolve the UI values once so generate() and the metadata report always agree.
    max_len = int(max_length or 2048)
    beams = 1 if decoding == "greedy" else int(num_beams or 3)
    penalty = float(repetition_penalty or 1.1)
    if max_len < 1:
        raise gr.Error("Max length must be at least 1.")
    if beams < 1:
        raise gr.Error("Beam count must be at least 1.")
    if penalty <= 0:
        raise gr.Error("Repetition penalty must be greater than 0.")

    # Fall back to CPU so the app still runs where no CUDA device is available.
    device = "cuda" if torch.cuda.is_available() else "cpu"
    started = time.perf_counter()
    try:
        model, tokenizer = get_model_and_tokenizer()
    except Exception as exc:
        raise gr.Error(f"Transcoda failed to load. Check container logs.\n\n{exc}") from exc

    try:
        # Inside the try: if the move raises part-way (e.g. CUDA OOM), the
        # finally below still returns every parameter to CPU.
        model.to(device)
        pil_image = image if isinstance(image, Image.Image) else Image.fromarray(image)
        pixel_values = preprocess_pil_image(pil_image).to(device)
        image_sizes = torch.tensor([[TARGET_H, TARGET_W]], device=device)
        # Generation starts from a single BOS token.
        input_ids = torch.full(
            (1, 1),
            int(model.config.bos_token_id),
            dtype=torch.long,
            device=device,
        )
        with torch.no_grad():
            output = model.generate(
                input_ids=input_ids,
                pixel_values=pixel_values,
                image_sizes=image_sizes,
                max_length=max_len,
                do_sample=False,
                num_beams=beams,
                repetition_penalty=penalty,
            )
        kern = tokenizer.decode(output[0], skip_special_tokens=True)
    finally:
        # Keep the shared cached model on CPU between calls and drop cached GPU memory.
        model.to("cpu")
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

    metadata = {
        "model": MODEL_ID,
        "revision": MODEL_REVISION,
        "device": device,
        "decoding": decoding,
        "num_beams": beams,
        "max_length": max_len,
        "repetition_penalty": penalty,
        "elapsed_ms": int((time.perf_counter() - started) * 1000),
        "output_chars": len(kern),
    }
    return kern, json.dumps(metadata, indent=2)
# Start loading the model on a daemon thread so it overlaps with UI startup;
# transcribe() reuses the cached result (or loads lazily if this has not finished).
threading.Thread(target=preload_model, daemon=True).start()

with gr.Blocks() as demo:
    gr.Markdown("# Transcoda OMR")
    with gr.Row():
        image = gr.Image(type="pil", label="Score page image")
        with gr.Column():
            decoding = gr.Radio(["greedy", "beam"], value="greedy", label="Decoding")
            max_length = gr.Number(value=2048, precision=0, label="Max length")
            num_beams = gr.Number(value=3, precision=0, label="Beam count")
            repetition_penalty = gr.Number(value=1.1, label="Repetition penalty")
            run = gr.Button("Transcribe")
    kern = gr.Textbox(label="Generated **kern", lines=24)
    metadata = gr.Code(label="Metadata", language="json")
    run.click(
        transcribe,
        inputs=[image, decoding, max_length, num_beams, repetition_penalty],
        outputs=[kern, metadata],
        api_name="transcribe",
    )

# Only start the server when run as a script, so importing the module (tests,
# tooling, Gradio reload mode) does not block on launch().
if __name__ == "__main__":
    demo.queue().launch()
|