# signlanguage / app.py
# Uploaded by huy00001 with huggingface_hub (commit 6f4e51f, verified).
"""
Hugging Face Space — Vietnamese Sign Language Recognition
Upload file .npz -> nhận diện từ ký hiệu -> dịch sang câu tiếng Việt
"""
import json
import logging
import os
import sys

import gradio as gr
import numpy as np
import torch
from huggingface_hub import hf_hub_download, snapshot_download
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
MODEL_REPO = "huy00001/vsl-recognition"  # Hub repo holding the label map and checkpoints
DEVICE = "cpu"  # map_location used when loading the CTC checkpoint
_pipeline = None  # filled in lazily by load_pipeline() and reused afterwards

# The STGCN-CTC model definition lives in the ``model/`` directory beside this
# file; prepend it to sys.path so the import below can find it.
sys.path.insert(0, os.path.join(os.path.dirname(__file__), "model"))
from stgcn_ctc import STGCNCTCModel  # noqa: E402  (must come after the sys.path tweak)
def load_pipeline():
    """Return the shared inference pipeline, building it on the first call.

    Downloads from the ``MODEL_REPO`` Hub repository:

    * ``label_map.json`` -- its ``"classes"`` list is the gloss vocabulary;
    * ``checkpoints/best_ctc.pt`` -- the ``"model_state"`` weights for an
      ``STGCNCTCModel`` sized to that vocabulary;
    * ``checkpoints/best_seq2seq/`` -- a transformers tokenizer and
      seq2seq model used to turn glosses into a sentence.

    Both models are switched to eval mode. The assembled dict is stored in the
    module-level ``_pipeline`` so later calls return it without reloading.
    If any step raises, nothing is cached and the next call retries.

    Returns:
        dict with keys ``"ctc"``, ``"tokenizer"``, ``"seq2seq"``,
        ``"classes"`` and ``"num_classes"``.
    """
    global _pipeline
    if _pipeline is None:
        # Gloss vocabulary.
        label_file = hf_hub_download(repo_id=MODEL_REPO, filename="label_map.json")
        with open(label_file, "r", encoding="utf-8") as fh:
            labels = json.load(fh)["classes"]

        # Skeleton -> gloss CTC recognizer.
        weights_file = hf_hub_download(repo_id=MODEL_REPO, filename="checkpoints/best_ctc.pt")
        recognizer = STGCNCTCModel(num_classes=len(labels))
        # NOTE(review): torch.load is called without an explicit ``weights_only``;
        # its default differs across torch versions -- confirm this checkpoint
        # loads under the torch version the Space pins.
        recognizer.load_state_dict(torch.load(weights_file, map_location=DEVICE)["model_state"])
        recognizer.eval()

        # Gloss -> Vietnamese sentence translator.
        snapshot_root = snapshot_download(repo_id=MODEL_REPO, allow_patterns="checkpoints/best_seq2seq/*")
        translator_dir = os.path.join(snapshot_root, "checkpoints", "best_seq2seq")
        tokenizer = AutoTokenizer.from_pretrained(translator_dir)
        translator = AutoModelForSeq2SeqLM.from_pretrained(translator_dir)
        translator.eval()

        _pipeline = {
            "ctc": recognizer,
            "tokenizer": tokenizer,
            "seq2seq": translator,
            "classes": labels,
            "num_classes": len(labels),
        }
    return _pipeline
def greedy_decode(log_probs, blank=0):
    """Best-path CTC decoding: per-frame argmax, merge repeats, drop blanks.

    Args:
        log_probs: tensor of per-frame class scores whose last dimension
            indexes classes. Expected to be 2-D (frames x classes) so that
            the per-frame argmax is a flat list of ints.
        blank: class id of the CTC blank token.

    Returns:
        list[int] of decoded class ids. A repeated id is emitted again only
        when some other id (e.g. the blank) appears between the repeats.
    """
    best = log_probs.argmax(dim=-1).tolist()
    # Pair every frame with its predecessor; the virtual frame before the
    # first one counts as a blank, so a leading non-blank id is always kept.
    return [cur for prev, cur in zip([blank] + best, best) if cur != blank and cur != prev]
def _load_sequence(npz_file):
    """Read the ``sequence`` array from an uploaded ``.npz`` archive as float32.

    ``npz_file`` may be a path (``str`` / ``os.PathLike``) or a file-like
    object exposing its on-disk path as ``.name``; in the latter case the path
    is used, so a closed temp-file wrapper still works.
    """
    if isinstance(npz_file, (str, os.PathLike)):
        source = npz_file
    else:
        source = getattr(npz_file, "name", npz_file)
    # The archive comes from an anonymous upload: never unpickle it, since a
    # crafted object array could execute arbitrary code. Plain numeric arrays
    # load without pickling. The context manager closes the archive's file handle.
    with np.load(source, allow_pickle=False) as archive:
        return archive["sequence"].astype(np.float32)


def _recognize_words(p, seq):
    """Run the CTC recognizer on one skeleton sequence and return gloss labels."""
    x = torch.from_numpy(seq).unsqueeze(0)  # add a leading batch dimension of 1
    with torch.no_grad():
        log_probs, _ = p["ctc"](x)
    # NOTE(review): [:, 0, :] implies a time-major (T, B, C) output; confirm
    # against STGCNCTCModel.
    indices = greedy_decode(log_probs[:, 0, :])
    # Class id 0 is the CTC blank (greedy_decode's default), so label i lives at
    # classes[i - 1]; ids outside 1..num_classes are dropped.
    return [p["classes"][i - 1] for i in indices if 1 <= i <= p["num_classes"]]


def _translate(p, gloss_text):
    """Translate a ``" | "``-joined gloss string into a sentence with the seq2seq model."""
    tok = p["tokenizer"]
    enc = tok(gloss_text, return_tensors="pt", truncation=True, max_length=64)
    with torch.no_grad():
        out_ids = p["seq2seq"].generate(**enc, max_new_tokens=128, num_beams=4)
    return tok.decode(out_ids[0], skip_special_tokens=True)


def predict(npz_file):
    """Gradio handler: recognize sign glosses in an uploaded file and translate them.

    Args:
        npz_file: value of the ``gr.File`` input -- a path or file-like object
            for a ``.npz`` archive containing a ``sequence`` array, or ``None``
            when nothing was uploaded.

    Returns:
        ``(words, sentence)`` strings for the two output textboxes. ``words``
        holds the recognized glosses joined with ``" | "``; on failure it holds
        a status or error message and ``sentence`` is empty.
    """
    if npz_file is None:
        return "Chưa có file", ""
    try:
        # Read (and thereby validate) the upload before the potentially slow
        # first-time model download.
        seq = _load_sequence(npz_file)
        p = load_pipeline()
        words = _recognize_words(p, seq)
        if not words:
            return "Không nhận diện được từ nào", ""
        text = " | ".join(words)
        return text, _translate(p, text)
    except Exception as e:
        # UI boundary: show the error in the textbox, but keep the traceback
        # in the logs instead of discarding it.
        logging.getLogger(__name__).exception("Prediction failed")
        return f"Lỗi: {str(e)}", ""
with gr.Blocks(title="Nhận diện ngôn ngữ ký hiệu tiếng Việt") as demo:
    # Header.
    gr.Markdown("## 🤟 Nhận diện ngôn ngữ ký hiệu tiếng Việt")
    gr.Markdown("Upload file `.npz` chứa skeleton sequence để nhận diện.")

    # Input: the skeleton archive and the button that triggers recognition.
    file_input = gr.File(file_types=[".npz"], label="File .npz")
    btn = gr.Button("Nhận diện", variant="primary")

    # Outputs, in the same order as predict()'s (words, sentence) return tuple.
    words_out = gr.Textbox(label="Từ ký hiệu nhận diện được")
    sentence_out = gr.Textbox(label="Câu tiếng Việt")

    btn.click(predict, [file_input], [words_out, sentence_out])

if __name__ == "__main__":
    demo.launch()