huy00001 committed on
Commit
6f4e51f
·
verified ·
1 Parent(s): fcff481

Upload app.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. app.py +25 -58
app.py CHANGED
@@ -1,8 +1,6 @@
1
  """
2
  Hugging Face Space — Vietnamese Sign Language Recognition
3
  Upload file .npz -> nhận diện từ ký hiệu -> dịch sang câu tiếng Việt
4
-
5
- Gradio UI chạy trên HF Spaces (CPU).
6
  """
7
  import os
8
  import sys
@@ -13,16 +11,13 @@ import gradio as gr
13
  from huggingface_hub import hf_hub_download, snapshot_download
14
  from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
15
 
16
- # ── Thay bằng repo của bạn ─────────────────────────────────────
17
  MODEL_REPO = "huy00001/vsl-recognition"
18
- # ──────────────────────────────────────────────────────────────
19
 
20
- # Thêm model code vào path
21
  sys.path.insert(0, os.path.join(os.path.dirname(__file__), "model"))
22
  from stgcn_ctc import STGCNCTCModel
23
 
24
  DEVICE = "cpu"
25
- _pipeline = None # lazy load
26
 
27
 
28
  def load_pipeline():
@@ -30,45 +25,33 @@ def load_pipeline():
30
  if _pipeline is not None:
31
  return _pipeline
32
 
33
- print("Downloading model files from HF Hub...")
34
-
35
- # Download label_map.json
36
  label_map_path = hf_hub_download(repo_id=MODEL_REPO, filename="label_map.json")
37
  with open(label_map_path, "r", encoding="utf-8") as f:
38
  data = json.load(f)
39
- label2idx = data["label2idx"]
40
- classes = data["classes"]
41
  num_classes = len(classes)
42
 
43
- # Download CTC checkpoint
44
- ctc_path = hf_hub_download(repo_id=MODEL_REPO, filename="checkpoints/best_ctc.pt")
45
  ctc_model = STGCNCTCModel(num_classes=num_classes)
46
- ckpt = torch.load(ctc_path, map_location=DEVICE)
47
  ctc_model.load_state_dict(ckpt["model_state"])
48
  ctc_model.eval()
49
 
50
- # Download seq2seq model
51
- seq2seq_dir = snapshot_download(
52
- repo_id=MODEL_REPO,
53
- allow_patterns="checkpoints/best_seq2seq/*",
54
- )
55
  seq2seq_path = os.path.join(seq2seq_dir, "checkpoints", "best_seq2seq")
56
- tokenizer = AutoTokenizer.from_pretrained(seq2seq_path)
57
- seq2seq = AutoModelForSeq2SeqLM.from_pretrained(seq2seq_path)
58
  seq2seq.eval()
59
 
60
- _pipeline = {
61
- "ctc": ctc_model,
62
- "tokenizer": tokenizer,
63
- "seq2seq": seq2seq,
64
- "classes": classes,
65
- "num_classes": num_classes,
66
- }
67
- print("Pipeline loaded.")
68
  return _pipeline
69
 
70
 
71
- def greedy_decode(log_probs: torch.Tensor, blank: int = 0) -> list[int]:
72
  indices = log_probs.argmax(dim=-1).tolist()
73
  decoded, prev = [], blank
74
  for idx in indices:
@@ -79,58 +62,42 @@ def greedy_decode(log_probs: torch.Tensor, blank: int = 0) -> list[int]:
79
 
80
 
81
  def predict(npz_file):
82
- """Gradio handler: nhận file .npz, trả về (words, sentence)"""
83
  if npz_file is None:
84
  return "Chưa có file", ""
85
-
86
  try:
87
- p = load_pipeline()
88
-
89
- # Stage 1b — CTC
90
  seq = np.load(npz_file, allow_pickle=True)["sequence"].astype(np.float32)
91
  x = torch.from_numpy(seq).unsqueeze(0)
 
92
  with torch.no_grad():
93
  log_probs, _ = p["ctc"](x)
94
  log_probs = log_probs[:, 0, :]
 
95
  indices = greedy_decode(log_probs)
96
  words = [p["classes"][i - 1] for i in indices if 1 <= i <= p["num_classes"]]
97
 
98
  if not words:
99
  return "Không nhận diện được từ nào", ""
100
 
101
- # Stage 2 — seq2seq
102
- text = " | ".join(words)
103
- enc = p["tokenizer"](text, return_tensors="pt", truncation=True, max_length=64)
104
  with torch.no_grad():
105
- out_ids = p["seq2seq"].generate(**enc, max_new_tokens=128, num_beams=4)
106
  sentence = p["tokenizer"].decode(out_ids[0], skip_special_tokens=True)
107
 
108
- return " | ".join(words), sentence
109
-
110
  except Exception as e:
111
  return f"Lỗi: {str(e)}", ""
112
 
113
 
114
- # ── Gradio UI ──────────────────────────────────────────────────
115
  with gr.Blocks(title="Nhận diện ngôn ngữ ký hiệu tiếng Việt") as demo:
116
- gr.Markdown("## Nhận diện ngôn ngữ ký hiệu tiếng Việt")
117
  gr.Markdown("Upload file `.npz` chứa skeleton sequence để nhận diện.")
118
-
119
- with gr.Row():
120
- file_input = gr.File(label="File .npz", file_types=[".npz"], type="filepath")
121
-
122
- btn = gr.Button("Nhận diện", variant="primary")
123
-
124
- with gr.Row():
125
- words_out = gr.Textbox(label="Từ ký hiệu nhận diện được")
126
- sentence_out = gr.Textbox(label="Câu tiếng Việt")
127
-
128
  btn.click(fn=predict, inputs=file_input, outputs=[words_out, sentence_out])
129
 
130
- gr.Examples(
131
- examples=[], # thêm file .npz mẫu vào đây nếu muốn
132
- inputs=file_input,
133
- )
134
-
135
  if __name__ == "__main__":
136
  demo.launch()
 
1
  """
2
  Hugging Face Space — Vietnamese Sign Language Recognition
3
  Upload file .npz -> nhận diện từ ký hiệu -> dịch sang câu tiếng Việt
 
 
4
  """
5
  import os
6
  import sys
 
11
  from huggingface_hub import hf_hub_download, snapshot_download
12
  from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
13
 
 
14
  MODEL_REPO = "huy00001/vsl-recognition"
 
15
 
 
16
  sys.path.insert(0, os.path.join(os.path.dirname(__file__), "model"))
17
  from stgcn_ctc import STGCNCTCModel
18
 
19
  DEVICE = "cpu"
20
+ _pipeline = None
21
 
22
 
23
  def load_pipeline():
 
25
  if _pipeline is not None:
26
  return _pipeline
27
 
28
+ # label_map
 
 
29
  label_map_path = hf_hub_download(repo_id=MODEL_REPO, filename="label_map.json")
30
  with open(label_map_path, "r", encoding="utf-8") as f:
31
  data = json.load(f)
32
+ classes = data["classes"]
 
33
  num_classes = len(classes)
34
 
35
+ # CTC model
36
+ ctc_path = hf_hub_download(repo_id=MODEL_REPO, filename="checkpoints/best_ctc.pt")
37
  ctc_model = STGCNCTCModel(num_classes=num_classes)
38
+ ckpt = torch.load(ctc_path, map_location=DEVICE)
39
  ctc_model.load_state_dict(ckpt["model_state"])
40
  ctc_model.eval()
41
 
42
+ # seq2seq model
43
+ seq2seq_dir = snapshot_download(repo_id=MODEL_REPO, allow_patterns="checkpoints/best_seq2seq/*")
 
 
 
44
  seq2seq_path = os.path.join(seq2seq_dir, "checkpoints", "best_seq2seq")
45
+ tokenizer = AutoTokenizer.from_pretrained(seq2seq_path)
46
+ seq2seq = AutoModelForSeq2SeqLM.from_pretrained(seq2seq_path)
47
  seq2seq.eval()
48
 
49
+ _pipeline = {"ctc": ctc_model, "tokenizer": tokenizer,
50
+ "seq2seq": seq2seq, "classes": classes, "num_classes": num_classes}
 
 
 
 
 
 
51
  return _pipeline
52
 
53
 
54
+ def greedy_decode(log_probs, blank=0):
55
  indices = log_probs.argmax(dim=-1).tolist()
56
  decoded, prev = [], blank
57
  for idx in indices:
 
62
 
63
 
64
  def predict(npz_file):
 
65
  if npz_file is None:
66
  return "Chưa có file", ""
 
67
  try:
68
+ p = load_pipeline()
 
 
69
  seq = np.load(npz_file, allow_pickle=True)["sequence"].astype(np.float32)
70
  x = torch.from_numpy(seq).unsqueeze(0)
71
+
72
  with torch.no_grad():
73
  log_probs, _ = p["ctc"](x)
74
  log_probs = log_probs[:, 0, :]
75
+
76
  indices = greedy_decode(log_probs)
77
  words = [p["classes"][i - 1] for i in indices if 1 <= i <= p["num_classes"]]
78
 
79
  if not words:
80
  return "Không nhận diện được từ nào", ""
81
 
82
+ text = " | ".join(words)
83
+ enc = p["tokenizer"](text, return_tensors="pt", truncation=True, max_length=64)
 
84
  with torch.no_grad():
85
+ out_ids = p["seq2seq"].generate(**enc, max_new_tokens=128, num_beams=4)
86
  sentence = p["tokenizer"].decode(out_ids[0], skip_special_tokens=True)
87
 
88
+ return text, sentence
 
89
  except Exception as e:
90
  return f"Lỗi: {str(e)}", ""
91
 
92
 
 
93
  with gr.Blocks(title="Nhận diện ngôn ngữ ký hiệu tiếng Việt") as demo:
94
+ gr.Markdown("## 🤟 Nhận diện ngôn ngữ ký hiệu tiếng Việt")
95
  gr.Markdown("Upload file `.npz` chứa skeleton sequence để nhận diện.")
96
+ file_input = gr.File(label="File .npz", file_types=[".npz"])
97
+ btn = gr.Button("Nhận diện", variant="primary")
98
+ words_out = gr.Textbox(label="Từ hiệu nhận diện được")
99
+ sentence_out = gr.Textbox(label="Câu tiếng Việt")
 
 
 
 
 
 
100
  btn.click(fn=predict, inputs=file_input, outputs=[words_out, sentence_out])
101
 
 
 
 
 
 
102
  if __name__ == "__main__":
103
  demo.launch()