vortexa64 commited on
Commit
59e524d
·
verified ·
1 Parent(s): d1a5836

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +43 -0
app.py ADDED
@@ -0,0 +1,43 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import onnxruntime as ort
3
+ import json
4
+ import torch
5
+
6
+ # Load vocab
7
+ with open("vocab.json", "r", encoding="utf-8") as f:
8
+ vocab = json.load(f)
9
+ inv_vocab = {i:tok for tok,i in vocab.items()}
10
+
11
+ pad_idx = vocab.get("<pad>", 0)
12
+ unk_idx = vocab.get("<unk>", 1)
13
+
14
+ # Load ONNX model
15
+ session = ort.InferenceSession("chat_model.onnx")
16
+
17
+ # Tokenizer helper
18
+ def tokenize(text):
19
+ return [vocab.get(tok, unk_idx) for tok in text.split(" ")]
20
+
21
+ def pad_sequence(seq, max_len=20):
22
+ seq = seq + [pad_idx]*(max_len - len(seq))
23
+ return seq[:max_len]
24
+
25
+ # Fungsi chat mini
26
+ def chat_onnx(input_text, max_len=20):
27
+ input_ids = pad_sequence(tokenize(input_text), max_len)
28
+ input_tensor = np.array([input_ids], dtype=np.int64)
29
+
30
+ output_ids = []
31
+ h = None # ONNX simple RNN ini biasanya stateless
32
+ for _ in range(max_len):
33
+ ort_inputs = {"input": input_tensor}
34
+ ort_outs = session.run(None, ort_inputs)
35
+ next_token = int(ort_outs[0][0, -1].argmax())
36
+ output_ids.append(next_token)
37
+ input_tensor = np.array([[next_token]], dtype=np.int64)
38
+
39
+ return " ".join([inv_vocab.get(i, "<unk>") for i in output_ids])
40
+
41
+ # Gradio interface
42
+ iface = gr.Interface(fn=chat_onnx, inputs="text", outputs="text")
43
+ iface.launch()