Spaces:
Sleeping
Sleeping
"""
Hugging Face Space — Vietnamese Sign Language Recognition.

Upload a .npz file -> recognize the sign words -> translate them into a
Vietnamese sentence.
"""
| import os | |
| import sys | |
| import json | |
| import torch | |
| import numpy as np | |
| import gradio as gr | |
| from huggingface_hub import hf_hub_download, snapshot_download | |
| from transformers import AutoTokenizer, AutoModelForSeq2SeqLM | |
# Hugging Face Hub repository from which the label map and both checkpoints
# are downloaded.
MODEL_REPO = "huy00001/vsl-recognition"

# Device passed as ``map_location`` when the CTC checkpoint is loaded.
DEVICE = "cpu"

# Cache filled by load_pipeline() on its first call; None until then.
_pipeline = None

# The ``stgcn_ctc`` module is imported from the "model" directory that sits
# next to this file, so that directory is put first on the import path.
sys.path.insert(
    0,
    os.path.join(os.path.dirname(__file__), "model"),
)
from stgcn_ctc import STGCNCTCModel
def load_pipeline():
    """Build the inference pipeline on first use and return the cached dict.

    The first call downloads ``label_map.json``, the CTC checkpoint and the
    seq2seq checkpoint directory from ``MODEL_REPO``, instantiates both models
    in eval mode and stores the result in the module-level ``_pipeline``.
    Every later call returns that same cached object.

    Returns:
        dict with the keys ``"ctc"``, ``"tokenizer"``, ``"seq2seq"``,
        ``"classes"`` and ``"num_classes"``.
    """
    global _pipeline
    if _pipeline is None:
        # Label map: its "classes" entry is the list of recognisable words.
        labels_file = hf_hub_download(repo_id=MODEL_REPO, filename="label_map.json")
        with open(labels_file, "r", encoding="utf-8") as fh:
            labels = json.load(fh)["classes"]
        label_count = len(labels)

        # Sign recogniser (CTC model) restored from its checkpoint.
        # NOTE(review): torch.load unpickles the checkpoint; consider
        # weights_only=True if the checkpoint contents permit it.
        weights_file = hf_hub_download(repo_id=MODEL_REPO, filename="checkpoints/best_ctc.pt")
        recognizer = STGCNCTCModel(num_classes=label_count)
        checkpoint = torch.load(weights_file, map_location=DEVICE)
        recognizer.load_state_dict(checkpoint["model_state"])
        recognizer.eval()

        # Seq2seq model and tokenizer, loaded from the downloaded snapshot.
        snapshot_root = snapshot_download(repo_id=MODEL_REPO, allow_patterns="checkpoints/best_seq2seq/*")
        translator_dir = os.path.join(snapshot_root, "checkpoints", "best_seq2seq")
        tok = AutoTokenizer.from_pretrained(translator_dir)
        translator = AutoModelForSeq2SeqLM.from_pretrained(translator_dir)
        translator.eval()

        _pipeline = dict(
            ctc=recognizer,
            tokenizer=tok,
            seq2seq=translator,
            classes=labels,
            num_classes=label_count,
        )
    return _pipeline
def greedy_decode(log_probs, blank=0):
    """Collapse a frame-wise argmax path into a CTC label sequence.

    The argmax over the last dimension of ``log_probs`` gives one label per
    frame. A label is kept only when it is not ``blank`` and differs from the
    label of the immediately preceding frame.

    Args:
        log_probs: Tensor whose last dimension indexes the classes.
        blank: Index of the CTC blank label.

    Returns:
        List of the kept label indices, in frame order.
    """
    path = log_probs.argmax(dim=-1).tolist()
    # Pair every frame with its predecessor. The first frame is paired with
    # ``blank`` so that a leading non-blank label is always kept.
    predecessors = [blank] + path[:-1]
    return [
        current
        for before, current in zip(predecessors, path)
        if current != blank and current != before
    ]
def predict(npz_file):
    """Recognise sign words in an uploaded ``.npz`` file and translate them.

    Args:
        npz_file: The upload as handed over by the Gradio file input (anything
            ``np.load`` accepts), or ``None`` when nothing was uploaded. The
            archive must contain an array stored under the key ``"sequence"``.

    Returns:
        A ``(words, sentence)`` tuple of strings. ``words`` is the recognised
        words joined with ``" | "`` and ``sentence`` is the seq2seq output.
        When no file was given, no word was recognised, or any step raised,
        ``words`` carries the user-facing message and ``sentence`` is ``""``.
    """
    if npz_file is None:
        return "Chưa có file", ""
    try:
        # Read the upload before touching the models, so an invalid file is
        # rejected without first downloading and loading the checkpoints.
        #
        # SECURITY: the archive is user-supplied. allow_pickle=True would let
        # a crafted file execute arbitrary code while being unpickled, so
        # pickled object arrays are refused here. The context manager closes
        # the underlying file handle, which np.load otherwise leaves open.
        with np.load(npz_file, allow_pickle=False) as archive:
            seq = archive["sequence"].astype(np.float32)

        p = load_pipeline()

        # Add a leading dimension of size 1 before calling the CTC model.
        x = torch.from_numpy(seq).unsqueeze(0)
        with torch.no_grad():
            log_probs, _ = p["ctc"](x)
        # Select index 0 along dim 1, leaving a 2-D tensor for decoding.
        indices = greedy_decode(log_probs[:, 0, :])

        # Index 0 is the blank used by greedy_decode, so label i corresponds
        # to classes[i - 1]; labels outside the known range are dropped.
        words = [p["classes"][i - 1] for i in indices if 1 <= i <= p["num_classes"]]
        if not words:
            return "Không nhận diện được từ nào", ""

        text = " | ".join(words)
        enc = p["tokenizer"](text, return_tensors="pt", truncation=True, max_length=64)
        with torch.no_grad():
            out_ids = p["seq2seq"].generate(**enc, max_new_tokens=128, num_beams=4)
        sentence = p["tokenizer"].decode(out_ids[0], skip_special_tokens=True)
        return text, sentence
    except Exception as e:
        # UI boundary: report the failure in the first output box instead of
        # letting the exception escape to Gradio.
        return f"Lỗi: {e}", ""
# Gradio UI: one file input, one button, and two text outputs fed by predict().
with gr.Blocks(title="Nhận diện ngôn ngữ ký hiệu tiếng Việt") as demo:
    # Page heading and a short usage hint.
    gr.Markdown("## 🤟 Nhận diện ngôn ngữ ký hiệu tiếng Việt")
    gr.Markdown("Upload file `.npz` chứa skeleton sequence để nhận diện.")

    # Upload widget restricted to .npz files, plus the trigger button.
    file_input = gr.File(
        label="File .npz",
        file_types=[".npz"],
    )
    btn = gr.Button("Nhận diện", variant="primary")

    # Output boxes: recognised words first, generated sentence second,
    # matching the order of the tuple returned by predict().
    words_out = gr.Textbox(label="Từ ký hiệu nhận diện được")
    sentence_out = gr.Textbox(label="Câu tiếng Việt")

    btn.click(
        fn=predict,
        inputs=file_input,
        outputs=[words_out, sentence_out],
    )

# Start the app only when this file is executed as a script.
if __name__ == "__main__":
    demo.launch()