# asl-translator / streamlit_app.py
# Hugging Face Space by tanmmayyy; last change: commit c8c3ed0 ("Update streamlit_app.py").
import streamlit as st
import cv2
import torch
import torch.nn as nn
import numpy as np
import mediapipe as mp
import pickle
from collections import deque
import PIL.Image
# Page chrome. set_page_config is issued before any other st.* output call,
# which older Streamlit releases require.
# The emoji is U+1F91F (the ASL "I love you" hand sign); the previous literals
# held UTF-8 bytes mis-decoded as cp1253 ("🀟", "β€”").
st.set_page_config(page_title="ASL Translator", page_icon="🤟", layout="wide")
st.title("🤟 ASL Sign Language Translator")
st.markdown("Show your hand to the camera — hold a sign steady to add it to the sentence.")
# ── Load model ────────────────────────────────────────────
@st.cache_resource
def load_model():
    """Build the landmark classifier and its label encoder (cached by st.cache_resource).

    Returns:
        (model, le): an ``ASLClassifier`` in eval mode carrying the weights from
        ``asl_model_best.pth`` (mapped onto the CPU), and the label encoder
        unpickled from ``label_encoder.pkl``. The encoder's ``classes_`` fix
        the width of the classifier's output layer.
    """

    class ASLClassifier(nn.Module):
        """Two-hidden-layer MLP: input_dim -> 128 -> 64 -> num_classes."""

        def __init__(self, input_dim=63, num_classes=28):
            super().__init__()
            # The attribute name ``net`` and the layer order must stay exactly
            # like this: load_state_dict matches parameters by key ("net.0.weight", ...).
            hidden_layers = [
                nn.Linear(input_dim, 128), nn.BatchNorm1d(128), nn.ReLU(), nn.Dropout(0.3),
                nn.Linear(128, 64), nn.BatchNorm1d(64), nn.ReLU(), nn.Dropout(0.2),
            ]
            self.net = nn.Sequential(*hidden_layers, nn.Linear(64, num_classes))

        def forward(self, x):
            return self.net(x)

    # Bundled artifact shipped with the app; never point pickle at uploaded data.
    with open("label_encoder.pkl", "rb") as fh:
        encoder = pickle.load(fh)

    classifier = ASLClassifier(num_classes=len(encoder.classes_))
    checkpoint = torch.load("asl_model_best.pth", map_location="cpu", weights_only=True)
    classifier.load_state_dict(checkpoint)
    # Eval mode: BatchNorm uses running stats (needed for single-sample batches)
    # and Dropout is disabled.
    classifier.eval()
    return classifier, encoder


model, le = load_model()
# ── MediaPipe ─────────────────────────────────────────────
@st.cache_resource
def load_hands():
    """Create the shared MediaPipe hand tracker used by the live camera loop.

    Returns:
        (hands, mp_hands): a ``Hands`` instance in video mode
        (static_image_mode=False) that tracks at most one hand with 0.7
        detection and tracking thresholds, plus the ``mp.solutions.hands``
        module itself (its ``HAND_CONNECTIONS`` is used when drawing).
    """
    solution = mp.solutions.hands
    tracker = solution.Hands(
        static_image_mode=False,
        max_num_hands=1,
        min_detection_confidence=0.7,
        min_tracking_confidence=0.7,
    )
    return tracker, solution


hands, mp_hands = load_hands()
# Landmark/skeleton overlay helper shared by both tabs.
mp_drawing = mp.solutions.drawing_utils
# ── Sidebar ───────────────────────────────────────────────
st.sidebar.header("Settings")
# Number of consecutive frames a smoothed sign must persist before it is committed.
hold_frames = st.sidebar.slider("Hold frames to confirm", 10, 40, 20)
# Predictions below this softmax probability are ignored by the live loop.
min_confidence = st.sidebar.slider("Min confidence", 0.5, 1.0, 0.75)
st.sidebar.markdown("---")
st.sidebar.markdown("**How to use:**")
# The arrows were previously mojibake ("β†’" = UTF-8 "→" read as cp1253).
for _tip in (
    "- Hold a sign steady → letter added",
    "- Sign `del` → delete last letter",
    "- Sign `space` → add space",
    "- Click **Clear** to reset sentence",
):
    st.sidebar.markdown(_tip)
# ── Session state ─────────────────────────────────────────
# Seed per-session values only once; later reruns keep whatever the camera
# loop or the Clear button stored there.
_session_defaults = {"sentence": "", "last_letter": "", "hold_count": 0}
for _state_name, _initial in _session_defaults.items():
    if _state_name not in st.session_state:
        st.session_state[_state_name] = _initial

# Recent confident predictions, used for majority-vote smoothing. This is a
# plain module-level object, so it starts empty again on every script rerun.
pred_buffer = deque(maxlen=7)
# ══════════════════════════════════════════════════════════
# TAB 1: Live webcam     TAB 2: Upload image
# ══════════════════════════════════════════════════════════
tab1, tab2 = st.tabs(
    [
        "Live webcam",  # local camera stream through the classifier
        "Upload image",  # one-off classification of an uploaded photo
    ]
)
# ── Shared inference helpers ──────────────────────────────
def _landmark_features(landmarks):
    """Flatten hand landmarks into the classifier's input vector.

    Every landmark contributes its (x, y, z) offset from landmark 0 (the
    wrist), making the features independent of where the hand sits in the
    frame. The classifier from ``load_model`` expects 63 values (21 x 3).
    """
    origin = landmarks[0]
    return [
        value
        for point in landmarks
        for value in (point.x - origin.x, point.y - origin.y, point.z - origin.z)
    ]


def _classify(features):
    """Score one feature vector with the cached classifier.

    Returns:
        (probs, label, confidence): the 1-D softmax row over ``le.classes_``,
        the most probable class decoded by the label encoder, and that class's
        probability as a Python float.
    """
    batch = torch.tensor([features], dtype=torch.float32)
    with torch.no_grad():
        probs = torch.softmax(model(batch), dim=1)
    best_prob, best_idx = probs.max(dim=1)
    label = le.inverse_transform(best_idx.numpy())[0]
    return probs[0], label, best_prob.item()


def _advance_hold(sign, target):
    """Count consecutive frames of ``sign``; commit it to the sentence after ``target``.

    ``del`` removes the last character, ``space`` appends a blank, and any
    other sign is appended verbatim. The counter restarts after each commit,
    so a sign that keeps being held is added again every ``target`` frames.
    """
    state = st.session_state
    if sign == state.last_letter:
        state.hold_count += 1
    else:
        state.hold_count = 0
        state.last_letter = sign
    # ">=" rather than "==": hold_count lives in session_state and survives
    # reruns, so after the slider is lowered it may already exceed the target.
    if state.hold_count >= target:
        if sign == "del":
            state.sentence = state.sentence[:-1]
        elif sign == "space":
            state.sentence += " "
        else:
            state.sentence += sign
        state.hold_count = 0


# ── TAB 1: Webcam ─────────────────────────────────────────
# Layout and placeholders only. The capture loop runs at the very end of the
# script (see "Live camera loop"), so the upload tab is built before it starts.
with tab1:
    col_cam, col_info = st.columns([2, 1])
    with col_cam:
        run = st.checkbox("Start camera", value=False)
        FRAME_WINDOW = st.image([])
    with col_info:
        st.markdown("### Current sign")
        sign_display = st.empty()
        conf_display = st.empty()
        st.markdown("### Sentence")
        sentence_display = st.empty()
        if st.button("Clear sentence"):
            st.session_state.sentence = ""
            st.session_state.last_letter = ""
            st.session_state.hold_count = 0
        st.markdown("---")
        st.markdown("### Model info")
        st.success("99.22% accuracy")
        # Class count is taken from the loaded encoder so it cannot drift from the model.
        st.info(f"{len(le.classes_)} classes · A–Z + del + space")
        st.info("63,673 training samples")

# ── TAB 2: Image upload ───────────────────────────────────
with tab2:
    st.markdown("### Test with an image")
    st.markdown("Upload a photo of a hand making an ASL sign — works great for testing on Hugging Face.")
    uploaded = st.file_uploader("Choose an image", type=["jpg", "jpeg", "png"])
    if uploaded:
        img_rgb = np.array(PIL.Image.open(uploaded).convert("RGB"))
        img_bgr = cv2.cvtColor(img_rgb, cv2.COLOR_RGB2BGR)
        # The cached `hands` tracker is configured for video
        # (static_image_mode=False), so a dedicated static-mode detector is
        # used for this single image. `with` closes it even if process() raises.
        with mp_hands.Hands(
            static_image_mode=True,
            max_num_hands=1,
            min_detection_confidence=0.5,
        ) as hands_static:
            result = hands_static.process(img_rgb)
        detected = result.multi_hand_landmarks[0] if result.multi_hand_landmarks else None

        col_img, col_result = st.columns([1, 1])
        with col_img:
            if detected is not None:
                mp_drawing.draw_landmarks(img_bgr, detected, mp_hands.HAND_CONNECTIONS)
            st.image(cv2.cvtColor(img_bgr, cv2.COLOR_BGR2RGB), caption="Uploaded image", use_container_width=True)
        with col_result:
            if detected is not None:
                probs, label, confidence = _classify(_landmark_features(detected.landmark))
                # topk(k) requires k <= number of classes.
                k = min(3, probs.numel())
                top_conf, top_idx = probs.topk(k)
                top_labels = le.inverse_transform(top_idx.numpy())
                st.markdown("### Prediction")
                st.markdown(
                    f"<h1 style='color:#00ff88;font-size:80px;margin:0'>{label}</h1>",
                    unsafe_allow_html=True
                )
                st.markdown(f"**Confidence:** {confidence:.2%}")
                st.markdown("---")
                st.markdown(f"**Top {k} predictions:**")
                for lbl, cf in zip(top_labels, top_conf.tolist()):
                    # Clamp guards st.progress against float rounding just above 1.0.
                    st.progress(min(cf, 1.0), text=f"{lbl} — {cf:.2%}")
            else:
                st.warning("No hand detected. Try a clearer image with better lighting.")
                st.markdown("**Tips:**")
                st.markdown("- Make sure your hand is clearly visible")
                st.markdown("- Good lighting helps a lot")
                st.markdown("- Try a plain background")

# ── Live camera loop ──────────────────────────────────────
# This loop only ends on a read failure or when Streamlit reruns/stops the
# script (e.g. the checkbox is unticked), which interrupts it from outside.
# The `finally` therefore owns the device release on every exit path.
if run:
    cap = cv2.VideoCapture(0)
    try:
        cap.set(cv2.CAP_PROP_FRAME_WIDTH, 640)
        cap.set(cv2.CAP_PROP_FRAME_HEIGHT, 480)
        while True:
            ok, frame = cap.read()
            if not ok:
                tab1.error("Camera not found.")
                break
            frame = cv2.flip(frame, 1)  # mirror so the preview behaves like a mirror
            result = hands.process(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
            smoothed = ""
            confidence = 0.0
            if result.multi_hand_landmarks:
                hand = result.multi_hand_landmarks[0]
                _, label, confidence = _classify(_landmark_features(hand.landmark))
                # Low-confidence frames neither feed the vote nor advance the
                # hold; voting right after an append keeps the buffer non-empty.
                if confidence >= min_confidence:
                    pred_buffer.append(label)
                    smoothed = max(set(pred_buffer), key=pred_buffer.count)
                    _advance_hold(smoothed, hold_frames)
                mp_drawing.draw_landmarks(frame, hand, mp_hands.HAND_CONNECTIONS)
                cv2.rectangle(frame, (10, 10), (300, 80), (0, 0, 0), -1)
                cv2.putText(frame, f"Sign: {smoothed}", (20, 45),
                            cv2.FONT_HERSHEY_SIMPLEX, 1.0, (0, 255, 120), 2)
                cv2.putText(frame, f"Conf: {confidence:.2f}", (20, 70),
                            cv2.FONT_HERSHEY_SIMPLEX, 0.6, (180, 180, 180), 1)
                # Clamp: a hold_count persisted from a larger slider value must
                # not draw past the end of the 250px track.
                progress = min(st.session_state.hold_count / hold_frames, 1.0)
                bar_w = int(progress * 250)
                cv2.rectangle(frame, (10, 88), (260, 102), (50, 50, 50), -1)
                cv2.rectangle(frame, (10, 88), (10 + bar_w, 102), (0, 255, 120), -1)
            else:
                cv2.putText(frame, "No hand detected", (20, 45),
                            cv2.FONT_HERSHEY_SIMPLEX, 0.8, (100, 100, 255), 2)

            sentence = st.session_state.sentence
            h, w = frame.shape[:2]
            cv2.rectangle(frame, (0, h - 45), (w, h), (0, 0, 0), -1)
            cv2.putText(frame, sentence or "...", (10, h - 12),
                        cv2.FONT_HERSHEY_SIMPLEX, 0.9, (255, 255, 255), 2)
            FRAME_WINDOW.image(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
            sign_display.markdown(
                f"<h1 style='color:#00ff88;font-size:64px;margin:0'>{smoothed or '—'}</h1>",
                unsafe_allow_html=True
            )
            conf_display.markdown(
                f"<p style='color:gray'>Confidence: {confidence:.2%}</p>",
                unsafe_allow_html=True
            )
            sentence_display.markdown(
                f"<div style='font-size:22px;padding:10px;background:#1e1e1e;"
                f"color:#00ff88;border-radius:8px;font-family:monospace;min-height:50px'>"
                f"{sentence or '...'}</div>",
                unsafe_allow_html=True
            )
    finally:
        cap.release()