File size: 3,715 Bytes
146225d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6f4e51f
146225d
 
 
 
 
 
 
6f4e51f
146225d
 
 
6f4e51f
146225d
 
6f4e51f
 
146225d
6f4e51f
146225d
 
 
6f4e51f
 
146225d
6f4e51f
 
146225d
 
6f4e51f
 
146225d
 
 
6f4e51f
146225d
 
 
 
 
 
 
 
 
 
 
 
 
6f4e51f
a128d97
146225d
6f4e51f
146225d
 
 
6f4e51f
146225d
 
 
 
 
 
6f4e51f
 
146225d
6f4e51f
146225d
 
6f4e51f
146225d
 
 
 
 
6f4e51f
146225d
6f4e51f
 
 
 
146225d
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
"""

Hugging Face Space — Vietnamese Sign Language Recognition

Upload file .npz -> nhận diện từ ký hiệu -> dịch sang câu tiếng Việt

"""
import os
import sys
import json
import torch
import numpy as np
import gradio as gr
from huggingface_hub import hf_hub_download, snapshot_download
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM

# Hugging Face Hub repository that hosts the label map and both model checkpoints.
MODEL_REPO = "huy00001/vsl-recognition"

# Put the "model" directory that sits next to this file on the import path so
# that the local `stgcn_ctc` module can be imported below.
sys.path.insert(0, os.path.join(os.path.dirname(__file__), "model"))
from stgcn_ctc import STGCNCTCModel

# Device passed as `map_location` when the CTC checkpoint is loaded.
DEVICE = "cpu"
# Cache for the loaded pipeline; populated by load_pipeline() on first use.
_pipeline = None


def load_pipeline():
    """Return the inference pipeline, building and caching it on first call.

    On the first call the label map, the CTC checkpoint and the seq2seq
    checkpoint directory are fetched from ``MODEL_REPO`` and loaded; the
    result is stored in the module-level ``_pipeline`` and reused afterwards.

    Returns:
        dict with the keys ``"ctc"``, ``"tokenizer"``, ``"seq2seq"``,
        ``"classes"`` and ``"num_classes"``.
    """
    global _pipeline
    if _pipeline is None:
        # Class vocabulary: the "classes" entry of label_map.json.
        vocab_file = hf_hub_download(repo_id=MODEL_REPO, filename="label_map.json")
        with open(vocab_file, "r", encoding="utf-8") as fh:
            class_names = json.load(fh)["classes"]
        n_classes = len(class_names)

        # Recognition model: weights live under the "model_state" key of the
        # checkpoint. NOTE(review): torch.load unpickles the file, so the
        # checkpoint must come from a trusted repository.
        weights_file = hf_hub_download(
            repo_id=MODEL_REPO, filename="checkpoints/best_ctc.pt"
        )
        recognizer = STGCNCTCModel(num_classes=n_classes)
        checkpoint = torch.load(weights_file, map_location=DEVICE)
        recognizer.load_state_dict(checkpoint["model_state"])
        recognizer.eval()

        # Translation model: only the best_seq2seq folder of the repo snapshot
        # is downloaded, then tokenizer and model are loaded from it.
        snapshot_root = snapshot_download(
            repo_id=MODEL_REPO, allow_patterns="checkpoints/best_seq2seq/*"
        )
        translator_dir = os.path.join(snapshot_root, "checkpoints", "best_seq2seq")
        tok = AutoTokenizer.from_pretrained(translator_dir)
        translator = AutoModelForSeq2SeqLM.from_pretrained(translator_dir)
        translator.eval()

        _pipeline = {
            "ctc": recognizer,
            "tokenizer": tok,
            "seq2seq": translator,
            "classes": class_names,
            "num_classes": n_classes,
        }
    return _pipeline


def greedy_decode(log_probs, blank=0):
    """Greedy CTC decoding: best class per step, repeats merged, blanks dropped.

    Args:
        log_probs: object exposing ``argmax(dim=-1).tolist()`` (e.g. a torch
            tensor); the arg-max over the last axis gives one index per step.
        blank: index of the CTC blank symbol.

    Returns:
        list of indices that are not ``blank`` and differ from the index of
        the immediately preceding step.
    """
    best = log_probs.argmax(dim=-1).tolist()
    # Pair every step with its predecessor; the first step is paired with
    # `blank`, which matches starting the scan with prev = blank.
    previous = [blank] + best[:-1]
    return [cur for before, cur in zip(previous, best)
            if cur != blank and cur != before]


def predict(npz_file):
    """Recognise sign glosses in an uploaded ``.npz`` file and translate them.

    Args:
        npz_file: value delivered by the Gradio file input (anything
            ``np.load`` accepts), or ``None`` when nothing was uploaded. The
            archive must contain an array stored under the key ``"sequence"``.

    Returns:
        A ``(words, sentence)`` tuple of strings: the recognised glosses
        joined with ``" | "`` and the sentence generated by the seq2seq model.
        When there is no file, nothing is recognised, or any error occurs,
        the first element carries the user-facing message and the second is
        an empty string.
    """
    if npz_file is None:
        return "Chưa có file", ""
    try:
        p = load_pipeline()

        # SECURITY: the file comes from an untrusted upload. Loading it with
        # allow_pickle=True would let a crafted archive execute arbitrary code
        # while being unpickled, so pickling is disabled; plain numeric arrays
        # do not need it. The `with` block closes the archive's file handle,
        # which was previously leaked on every request.
        with np.load(npz_file, allow_pickle=False) as archive:
            seq = archive["sequence"].astype(np.float32)
        x = torch.from_numpy(seq).unsqueeze(0)  # add a leading batch axis

        with torch.no_grad():
            log_probs, _ = p["ctc"](x)
            # Keep index 0 of the second axis — presumably the single batch
            # item of a (time, batch, classes) output; TODO confirm.
            log_probs = log_probs[:, 0, :]

        indices = greedy_decode(log_probs)
        # Index 0 is the CTC blank, so class i maps to classes[i - 1].
        words = [p["classes"][i - 1] for i in indices if 1 <= i <= p["num_classes"]]

        if not words:
            return "Không nhận diện được từ nào", ""

        text = " | ".join(words)
        enc = p["tokenizer"](text, return_tensors="pt", truncation=True, max_length=64)
        with torch.no_grad():
            out_ids = p["seq2seq"].generate(**enc, max_new_tokens=128, num_beams=4)
        sentence = p["tokenizer"].decode(out_ids[0], skip_special_tokens=True)

        return text, sentence
    except Exception as e:
        # UI boundary: show the failure in the output box instead of raising.
        return f"Lỗi: {str(e)}", ""


# Gradio UI: one file input, one button and two text outputs. Clicking the
# button calls predict() with the uploaded file and writes the returned
# (words, sentence) pair into the two text boxes.
with gr.Blocks(title="Nhận diện ngôn ngữ ký hiệu tiếng Việt") as demo:
    gr.Markdown("## 🤟 Nhận diện ngôn ngữ ký hiệu tiếng Việt")
    gr.Markdown("Upload file `.npz` chứa skeleton sequence để nhận diện.")
    # Only .npz files are offered in the upload dialog.
    file_input   = gr.File(label="File .npz", file_types=[".npz"])
    btn          = gr.Button("Nhận diện", variant="primary")
    # First output: recognised glosses; second output: generated sentence.
    words_out    = gr.Textbox(label="Từ ký hiệu nhận diện được")
    sentence_out = gr.Textbox(label="Câu tiếng Việt")
    btn.click(fn=predict, inputs=file_input, outputs=[words_out, sentence_out])

# Start the web server only when the file is executed as a script.
if __name__ == "__main__":
    demo.launch()