vedaco committed on
Commit
06758b5
Β·
verified Β·
1 Parent(s): 1604c66

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +124 -73
app.py CHANGED
@@ -4,6 +4,7 @@ import threading
4
  import time
5
  import os
6
  import json
 
7
  from model import VedaProgrammingLLM
8
  from tokenizer import VedaTokenizer
9
  from database import db
@@ -15,106 +16,156 @@ model = None
15
  tokenizer = None
16
  current_id = -1
17
 
18
- # Initialize
 
 
 
 
 
 
 
 
 
19
  def init():
 
20
  global model, tokenizer
21
-
22
  conf_path = os.path.join(MODEL_DIR, "config.json")
23
  weights_path = os.path.join(MODEL_DIR, "weights.h5")
24
-
25
- if os.path.exists(weights_path) and os.path.exists(conf_path):
26
- with open(conf_path) as f: conf = json.load(f)
 
 
 
27
  tokenizer = VedaTokenizer()
28
- tokenizer.load(os.path.join(MODEL_DIR, "tokenizer.json"))
 
29
  model = VedaProgrammingLLM(**conf)
30
- model(tf.zeros((1, conf['max_length'])))
 
 
 
 
31
  model.load_weights(weights_path)
32
- else:
33
- print("Training initial model...")
34
- VedaTrainer().train(epochs=15)
35
- init()
36
 
37
- # Auto-train loop
38
- def auto_train():
 
 
 
 
 
 
39
  while True:
40
- time.sleep(300) # Check every 5 mins
41
  try:
42
  data = db.get_unused_distillation()
43
- if len(data) >= 5:
44
- print("Auto-training on teacher data...")
45
- text = "\n".join([f"<USER> {r[1]}\n<ASSISTANT> {r[2]}" for r in data])
46
- VedaTrainer().train(epochs=5, extra_data=text)
47
  db.mark_used([r[0] for r in data])
48
  init()
49
- except:
50
- pass
51
 
52
- threading.Thread(target=auto_train, daemon=True).start()
53
 
54
- def is_good(text):
55
- if not text or len(text) < 10: return False
56
- if "arr[" in text and "return" not in text: return False # Gibberish check
 
 
 
 
 
 
 
 
57
  return True
58
 
59
- def clean_response(text: str) -> str:
60
- if not text: return ""
61
- text = text.replace("<CODE>", "\n```python\n").replace("<ENDCODE>", "\n```\n")
62
- for token in ["<PAD>", "<UNK>", "<START>", "<END>", "<USER>", "<ASSISTANT>"]:
63
- text = text.replace(token, "")
64
- if text.strip().startswith("```") and text.strip().endswith("```"):
65
- content = text.strip()[3:-3]
66
- if content.startswith("python"): content = content[6:]
67
- if not any(k in content for k in ["def ", "class ", "import ", "print(", "="]):
68
- text = content.strip()
69
- return text.strip()
70
 
71
- def respond(msg, history):
 
 
 
 
 
72
  global current_id
73
- if not msg.strip(): return "", history
74
-
75
- # Ensure history is a list
76
- if history is None: history = []
77
-
78
- # 1. Try student
79
- prompt = f"<USER> {msg}\n<ASSISTANT>"
 
 
 
80
  toks = tokenizer.encode(prompt)
81
- out = model.generate(toks, max_new_tokens=200)
82
- resp = tokenizer.decode(out).split("<ASSISTANT>")[-1].split("<USER>")[0].strip()
83
-
 
 
 
 
 
 
84
  resp = clean_response(resp)
85
 
86
- # 2. Check quality & fallback
87
- if not is_good(resp) and teacher.is_available():
88
- teacher_resp = teacher.ask(msg)
89
- if teacher_resp:
90
- resp = teacher_resp
91
- db.save_distillation(msg, teacher_resp) # Save for learning
92
-
93
- current_id = db.save_conversation(msg, resp)
94
-
95
- # FIX: Append list [user_msg, bot_msg] (Tuples format)
96
- # This matches the default Chatbot behavior (no type="messages")
97
- history.append([msg, resp])
98
-
 
 
 
99
  return "", history
100
 
101
- def feedback(vote):
102
- if current_id > 0: db.update_feedback(current_id, 1 if vote=="good" else -1)
103
 
104
- # UI
 
 
 
 
 
 
 
 
 
 
 
 
105
  init()
106
- with gr.Blocks(title="Veda") as demo:
 
 
107
  gr.Markdown("# πŸ•‰οΈ Veda Assistant")
108
-
109
- # FIX: Removed type="messages", relies on default list-of-lists
110
- chat = gr.Chatbot(height=400)
111
- msg = gr.Textbox(label="Message")
112
-
 
 
113
  with gr.Row():
114
- gr.Button("πŸ‘").click(lambda: feedback("good"))
115
- gr.Button("πŸ‘Ž").click(lambda: feedback("bad"))
116
-
117
- msg.submit(respond, [msg, chat], [msg, chat])
 
 
118
 
119
- # Launch
120
  demo.launch(server_name="0.0.0.0", server_port=7860)
 
4
  import time
5
  import os
6
  import json
7
+
8
  from model import VedaProgrammingLLM
9
  from tokenizer import VedaTokenizer
10
  from database import db
 
16
  tokenizer = None
17
  current_id = -1
18
 
19
+
20
def clean_response(text: str) -> str:
    """Convert model control tokens into display text.

    Rewrites <CODE>/<ENDCODE> into fenced python code blocks, removes all
    other special tokens, and trims surrounding whitespace.
    """
    if not text:
        return ""
    cleaned = text.replace("<CODE>", "\n```python\n")
    cleaned = cleaned.replace("<ENDCODE>", "\n```\n")
    # Drop every remaining control token outright.
    control_tokens = ("<PAD>", "<UNK>", "<START>", "<END>", "<USER>", "<ASSISTANT>")
    for marker in control_tokens:
        cleaned = cleaned.replace(marker, "")
    return cleaned.strip()
27
+
28
+
29
def init():
    """Load the student model and tokenizer from MODEL_DIR.

    If any artifact (config, weights, tokenizer) is missing, train an
    initial model once and then load.  The original implementation
    recursed into ``init()`` after training; if the trainer ever failed
    to write the artifacts that recursion was unbounded.  This version
    trains at most once and fails loudly instead.

    Raises:
        RuntimeError: if the artifacts are still missing after training.
    """
    global model, tokenizer

    conf_path = os.path.join(MODEL_DIR, "config.json")
    weights_path = os.path.join(MODEL_DIR, "weights.h5")
    tok_path = os.path.join(MODEL_DIR, "tokenizer.json")
    artifacts = (conf_path, weights_path, tok_path)

    if not all(os.path.exists(p) for p in artifacts):
        print("[Init] No model found -> Training initial model...")
        VedaTrainer().train(epochs=10)
        print("[Init] Training done -> Loading model...")

    missing = [p for p in artifacts if not os.path.exists(p)]
    if missing:
        # Training ran but did not produce these files; do not recurse.
        raise RuntimeError(f"[Init] Missing model artifacts: {missing}")

    with open(conf_path, "r") as f:
        conf = json.load(f)

    tokenizer = VedaTokenizer()
    tokenizer.load(tok_path)

    model = VedaProgrammingLLM(**conf)

    # Dummy forward pass builds the Keras graph so load_weights can map layers.
    max_len = conf.get("max_length", 512)
    model(tf.zeros((1, max_len), dtype=tf.int32))

    model.load_weights(weights_path)
    print("[Init] Model loaded.")
58
+
59
+
60
def auto_train_loop():
    """Background auto-train on teacher samples if available.

    Every 5 minutes, pulls unused distillation rows from the DB; when at
    least 5 are pending, fine-tunes the student on them, marks them used,
    and reloads the model.  Any failure is logged and skipped (best-effort).
    """
    poll_interval_s = 300  # 5 min
    while True:
        time.sleep(poll_interval_s)
        try:
            samples = db.get_unused_distillation()
            if not samples or len(samples) < 5:
                continue
            print(f"[AutoTrain] Training on {len(samples)} teacher samples...")
            corpus = "\n".join(
                f"<USER> {row[1]}\n<ASSISTANT> {row[2]}" for row in samples
            )
            VedaTrainer().train(epochs=3, extra_data=corpus)
            db.mark_used([row[0] for row in samples])
            init()
        except Exception as exc:
            print("[AutoTrain] skipped:", exc)
74
 
 
75
 
76
def is_good(text: str) -> bool:
    """Heuristic quality gate for a student-model reply.

    Rejects empty/short output, an `arr[`-style fragment with neither a
    function definition nor a return, and the canned greeting filler.
    """
    if not text:
        return False
    stripped = text.strip()
    too_short = len(stripped) < 20
    # basic gibberish detectors
    looks_gibberish = (
        "arr[" in stripped
        and "def " not in stripped
        and "return" not in stripped
    )
    canned_greeting = "hello how are you" in stripped.lower()
    return not (too_short or looks_gibberish or canned_greeting)
88
 
 
 
 
 
 
 
 
 
 
 
 
89
 
90
def respond(user_msg, history):
    """Chat handler: student model first, teacher fallback on poor output.

    ``history`` uses the messages format — a list of dicts shaped
    ``{"role": "user"|"assistant", "content": "..."}``.  Returns the
    cleared textbox value and the updated history.
    """
    global current_id

    if history is None:
        history = []

    msg_text = (user_msg or "").strip()
    if not msg_text:
        return "", history

    # Student response
    prompt = f"<USER> {msg_text}\n<ASSISTANT>"
    token_ids = tokenizer.encode(prompt)
    generated = model.generate(token_ids, max_new_tokens=200)
    raw = tokenizer.decode(generated)

    # Keep only the assistant segment: text after the last <ASSISTANT>
    # marker, truncated at the first following <USER> marker.
    _, _, tail = raw.rpartition("<ASSISTANT>")
    reply, _, _ = tail.partition("<USER>")

    reply = clean_response(reply)

    # Teacher fallback
    if not is_good(reply) and teacher.is_available():
        teacher_reply = teacher.ask(msg_text)
        if teacher_reply:
            reply = teacher_reply
            try:
                db.save_distillation(msg_text, teacher_reply)
            except Exception as exc:
                print("[DB] save_distillation failed:", exc)

    current_id = db.save_conversation(msg_text, reply)

    history.append({"role": "user", "content": msg_text})
    history.append({"role": "assistant", "content": reply})
    return "", history
136
 
 
 
137
 
138
def feedback_up():
    """Record a thumbs-up vote for the most recently saved conversation."""
    conv_id = current_id
    if conv_id > 0:
        db.update_feedback(conv_id, 1)
    return "Saved πŸ‘"
142
+
143
+
144
def feedback_down():
    """Record a thumbs-down vote for the most recently saved conversation."""
    conv_id = current_id
    if conv_id > 0:
        db.update_feedback(conv_id, -1)
    return "Saved πŸ‘Ž"
148
+
149
+
150
# --- startup ---
init()
threading.Thread(target=auto_train_loop, daemon=True).start()

with gr.Blocks(title="Veda Assistant") as demo:
    gr.Markdown("# πŸ•‰οΈ Veda Assistant")

    # NOTE(review): this Gradio version rejects the `type=` kwarg on
    # Chatbot, so it is deliberately omitted here.
    chat = gr.Chatbot(height=400, value=[])
    msg = gr.Textbox(label="Message", placeholder="Write bubble sort in python")
    status = gr.Textbox(label="Status", interactive=False)

    with gr.Row():
        up_btn = gr.Button("πŸ‘")
        down_btn = gr.Button("πŸ‘Ž")

    msg.submit(respond, inputs=[msg, chat], outputs=[msg, chat])
    up_btn.click(feedback_up, outputs=status)
    down_btn.click(feedback_down, outputs=status)

demo.launch(server_name="0.0.0.0", server_port=7860)