import streamlit as st
import torch
import cv2
import numpy as np
import easyocr
import os
import io
import time
from gtts import gTTS
from PIL import Image, ImageOps
from transformers import TrOCRProcessor, VisionEncoderDecoderModel
# ═══════════════════════════════════════════════════════════════
# UI CONFIGURATION & ATOMIC CSS OVERRIDES
# ═══════════════════════════════════════════════════════════════
# Page chrome: wide layout, sidebar starts collapsed.
st.set_page_config(page_title="Handwriting Engine", layout="wide", initial_sidebar_state="collapsed")
# NOTE(review): this markdown block is empty in this copy — the custom CSS it
# presumably injected appears to have been stripped; restore it if available.
st.markdown("""
""", unsafe_allow_html=True)
# ═══════════════════════════════════════════════════════════════
# MODELS & OCR LOGIC
# ═══════════════════════════════════════════════════════════════
# THE KILL-SWITCH: show_spinner=False completely deletes the un-styleable white cache boxes
# THE KILL-SWITCH: show_spinner=False completely deletes the un-styleable white cache boxes
@st.cache_resource(show_spinner=False)
def load_vision_engine():
    """Build and cache a single English EasyOCR reader (GPU when available)."""
    import logging

    # EasyOCR logs noisy startup warnings; keep only errors.
    logging.getLogger("easyocr").setLevel(logging.ERROR)
    use_gpu = torch.cuda.is_available()
    return easyocr.Reader(['en'], gpu=use_gpu)
@st.cache_resource(show_spinner=False)
def load_trocr_model(model_path):
    """Load a cached TrOCR processor/model pair for *model_path*.

    Moves the model to CUDA when available (and casts it to half precision),
    then rebuilds every sinusoidal positional-embedding table so its device
    and dtype match the model — without this, the fp32 buffers created at
    load time clash with a half-precision model at generate() time.

    Returns:
        (processor, model, device) — model is in eval mode.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    proc = TrOCRProcessor.from_pretrained(model_path)
    model = VisionEncoderDecoderModel.from_pretrained(model_path)
    model.to(device)
    if device.type == "cuda":
        model = model.half()

    # ─── THE ACTUAL ROOT-CAUSE FIX ───
    # Regenerate each sinusoidal table via the class's own factory so the
    # replacement matches the model's current dtype/device exactly.
    for module in model.modules():
        if "TrOCRSinusoidalPositionalEmbedding" not in type(module).__name__:
            continue
        num_positions, embedding_dim = module.weights.shape
        rebuilt = type(module).get_embedding(
            num_positions,
            embedding_dim,
            padding_idx=getattr(module, "padding_idx", None),
        )
        module.weights = rebuilt.to(device=device, dtype=model.dtype)

    model.eval()
    return proc, model, device
def extract_lines(pil_img, reader):
    """Detect text with EasyOCR and fuse word boxes into padded line crops.

    Word-level boxes whose vertical centers fall within 0.6x the median box
    height of an existing line are merged into that line; lines are returned
    top-to-bottom as PIL crops with a 40px white border (TrOCR-friendly).

    Returns:
        list of PIL.Image line crops (empty list when nothing is detected).
    """
    bgr = cv2.cvtColor(np.array(pil_img), cv2.COLOR_RGB2BGR)
    detections = reader.readtext(bgr, paragraph=False)

    boxes = []
    for bbox, _, _ in detections:
        xs = [pt[0] for pt in bbox]
        ys = [pt[1] for pt in bbox]
        boxes.append({'x_min': min(xs), 'x_max': max(xs),
                      'y_min': min(ys), 'y_max': max(ys)})
    if not boxes:
        return []

    boxes.sort(key=lambda b: b['y_min'])
    median_h = np.median([b['y_max'] - b['y_min'] for b in boxes])
    y_tol = median_h * 0.6

    # Greedy vertical fusion: attach each box to the first line whose
    # center is within tolerance, else start a new line.
    fused = []
    for box in boxes:
        cy = (box['y_min'] + box['y_max']) / 2.0
        target = None
        for line in fused:
            if abs(cy - (line['y_min'] + line['y_max']) / 2.0) < y_tol:
                target = line
                break
        if target is None:
            fused.append(box.copy())
        else:
            target['x_min'] = min(target['x_min'], box['x_min'])
            target['x_max'] = max(target['x_max'], box['x_max'])
            target['y_min'] = min(target['y_min'], box['y_min'])
            target['y_max'] = max(target['y_max'], box['y_max'])

    crops = []
    for line in sorted(fused, key=lambda b: b['y_min']):
        # Pad each line box slightly, clamped to the image bounds.
        left = max(0, int(line['x_min']) - 20)
        top = max(0, int(line['y_min']) - 15)
        right = min(pil_img.width, int(line['x_max']) + 20)
        bottom = min(pil_img.height, int(line['y_max']) + 15)
        crop = pil_img.crop((left, top, right, bottom))
        crops.append(ImageOps.expand(crop, border=40, fill=(255, 255, 255)))
    return crops
def main():
    """Render the Streamlit UI and drive the two-stage (EasyOCR -> TrOCR) pipeline."""
    # ─── Header row: title center, info popover right ───
    col_t1, col_t2, col_t3 = st.columns([1, 8, 1])
    with col_t2:
        # NOTE(review): the original inline-HTML title markup was stripped from
        # this copy; reconstructed as a minimal heading — restore styling if known.
        st.markdown('<h1>Handwronging &middot; Handwriting OCR</h1>', unsafe_allow_html=True)
    with col_t3:
        with st.popover("INFO"):
            st.markdown("### 🧠 Forensic Neural Architecture")
            st.write("This engine operates in a two-stage forensic sequence designed to maximize character fidelity. First, **EasyOCR** maps the image using mathematical line fusion, isolating text rows. Second, a **TrOCR Transformer** synthesizes the features into text. It may take a long time if ran online.")
            st.markdown("---")
            st.markdown("### ⚙️ The Neural Engines")
            st.write("**Model V13 (Specialist):** I trained this specific model myself using the **IAM Handwriting Database** (over 65,000 instances). It is highly optimized for cursive loops and manual pen-strokes. It is excellent for handwritten manuscripts but might struggle with standard modern print.")
            st.write("**Microsoft Large (1.3B Fallback):** A massive generalist model trained on millions of varied script and print examples. It is better for general use cases, complex historical documents, or heavily degraded text where V13 might struggle.")

    # One-shot session-state init keeps reruns idempotent.
    if "image_data" not in st.session_state:
        st.session_state.update({"image_data": None, "ocr_results": None})

    reader = load_vision_engine()
    c_left, c_right = st.columns([1, 2], gap="large")
    run_scan_trigger = False

    with c_left:
        # ─── THE REPOSITORY MAP ───
        model_choice = st.selectbox("SELECT MODEL", ["V13 Specialist", "Microsoft Large"])
        st.markdown("", unsafe_allow_html=True)
        m_map = {
            "V13 Specialist": "Hypernova823/ReadAI",
            "Microsoft Large": "microsoft/trocr-large-handwritten",
        }
        if st.session_state.image_data is None:
            # NOTE(review): the styled upload placeholder HTML was stripped from
            # this copy; minimal reconstruction of the visible text below.
            st.markdown("""
<div class="upload-zone">
    <span class="material-icons">add_a_photo</span>
    <div>Initialize Data Input</div>
    <div>BROWSE LOCAL STORAGE</div>
</div>
""", unsafe_allow_html=True)
            uploaded = st.file_uploader("Upload", type=['png', 'jpg', 'jpeg'], label_visibility="hidden")
            if uploaded:
                st.session_state.image_data = Image.open(uploaded).convert("RGB")
                st.rerun()
        else:
            st.image(st.session_state.image_data, width=350)
            btn_col1, btn_col2 = st.columns(2)
            with btn_col1:
                if st.button("REMOVE IMAGE"):
                    st.session_state.update({"image_data": None, "ocr_results": None})
                    st.rerun()
            with btn_col2:
                if st.button("RUN NEURAL SCAN"):
                    run_scan_trigger = True

    with c_right:
        if run_scan_trigger:
            start = time.time()
            # The fully-styled, dark-mode spinner handles the wait time so the
            # app doesn't freeze or refresh weirdly.
            with st.spinner("Allocating Neural Resources & Loading Weights..."):
                proc, model, device = load_trocr_model(m_map[model_choice])
                crops = extract_lines(st.session_state.image_data, reader)
                decoded, scores = [], []
                total_crops = len(crops)
                if total_crops > 0:
                    # ─── DYNAMIC PROGRESS BAR INJECTION ───
                    progress_bar = st.progress(0, text="Initializing Neural Matrix...")
                    for idx, crop in enumerate(crops):
                        # Update progress text & percentage per line crop.
                        pct = int((idx / total_crops) * 100)
                        progress_bar.progress(pct, text=f"Synthesizing segment {idx + 1} out of {total_crops} | {pct}% Complete...")
                        pixel_values = proc(crop, return_tensors="pt").pixel_values.to(device)
                        if device.type == "cuda":
                            pixel_values = pixel_values.half()
                        with torch.no_grad():
                            out = model.generate(pixel_values, max_new_tokens=64, return_dict_in_generate=True, output_scores=True)
                        decoded.append(proc.batch_decode(out.sequences, skip_special_tokens=True)[0].strip())
                        # Confidence is best-effort: some model/config combos
                        # can't yield transition scores — skip, don't abort.
                        # (Was a bare `except:`, which also swallowed
                        # KeyboardInterrupt/SystemExit.)
                        try:
                            scores.extend(np.exp(model.compute_transition_scores(out.sequences, out.scores, normalize_logits=True)[0].cpu().numpy()))
                        except Exception:
                            pass
                    # Snap to 100% just before closing out.
                    progress_bar.progress(100, text="Sequence Complete. Compiling output...")
                    time.sleep(0.3)
            full_text = "\n".join(decoded)  # joined once (was computed twice)
            st.session_state.ocr_results = {
                "text": full_text,
                "time": time.time() - start,
                "words": len(full_text.split()),
                "conf": np.mean(scores) * 100 if scores else 0,
            }
            st.rerun()
        elif st.session_state.ocr_results:
            res = st.session_state.ocr_results
            s1, s2, s3 = st.columns(3)
            # NOTE(review): the styled stat-tile HTML was stripped from this copy;
            # reconstructed minimally. The middle tile's content was lost entirely —
            # presumably it showed the word count; confirm against the original UI.
            s1.markdown(f'<div><b>{res["time"]:.1f}s</b><br>Latency</div>', unsafe_allow_html=True)
            s2.markdown(f'<div><b>{res["words"]}</b><br>Words</div>', unsafe_allow_html=True)
            s3.markdown(f'<div><b>{res["conf"]:.1f}%</b><br>Confidence</div>', unsafe_allow_html=True)
            st.markdown(f'<div>{res["text"]}</div>', unsafe_allow_html=True)
            # Text-to-speech playback of the recognized text.
            tts = gTTS(text=res["text"], lang='en')
            fp = io.BytesIO()
            tts.write_to_fp(fp)
            fp.seek(0)
            st.audio(fp, format='audio/mp3')
        else:
            st.markdown('<div>AWAITING SCAN SEQUENCE...</div>', unsafe_allow_html=True)


if __name__ == "__main__":
    main()