gijl commited on
Commit
f5c5361
·
verified ·
1 Parent(s): 60287a3

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +33 -30
app.py CHANGED
@@ -5,72 +5,75 @@ import torch.nn.functional as F
5
  from huggingface_hub import hf_hub_download
6
  from model import MedicalMasterAI
7
 
8
- device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 
9
 
 
10
  with open("tokenizer_config.json", "r", encoding="utf-8") as f:
11
  vocab = json.load(f)
12
-
13
  stoi = vocab["stoi"]
14
  itos = vocab["itos"]
15
 
16
  def encode(text):
17
- return [stoi.get(c, 0) for c in text]
18
 
19
  def decode(ids):
20
  return "".join([itos.get(str(i), "") for i in ids])
21
 
 
22
  try:
23
  model = MedicalMasterAI(vocab_size=115, n_layer=48, n_head=8, n_embd=768)
24
-
25
- print("جاري سحب ملف الأوزان من المستودع...")
26
  model_path = hf_hub_download(repo_id="gijl/Medical-Master-1.5B", filename="pytorch_model.bin")
27
- print("تم التحميل بنجاح. جاري قراءة الأوزان...")
28
-
29
  state_dict = torch.load(model_path, map_location=device, weights_only=True)
30
-
31
- # إضافة strict=False لتجاهل طبقات الصور (image_projection) بأمان
32
  model.load_state_dict(state_dict, strict=False)
33
-
34
- model.to(device)
35
- model.eval()
36
  model_loaded = True
37
- print("النموذج جاهز للعمل!")
38
  except Exception as e:
39
- print(f"Error loading model: {e}")
40
  model_loaded = False
41
 
 
42
  def medical_chat(message, history):
43
  if not model_loaded:
44
- return "حدث خطأ أثناء تحميل أوزان النموذج. يرجى مراجعة السجلات."
45
-
46
- prompt = f"Question: {message} Answer:"
 
 
47
  idx = torch.tensor([encode(prompt)], dtype=torch.long).to(device)
48
 
49
- max_new_chars = 200
50
- generated_ids = []
51
 
52
- with torch.no_grad():
53
- for _ in range(max_new_chars):
54
- # تغيير 1024 إلى 256 ليتطابق مع حجم أوزان التدريب
 
55
  idx_cond = idx[:, -256:]
56
 
57
  logits = model(idx_cond)
58
- logits = logits[:, -1, :]
59
- temperature = 0.8
60
- logits = logits / temperature
61
  probs = F.softmax(logits, dim=-1)
62
  idx_next = torch.multinomial(probs, num_samples=1)
63
 
 
64
  idx = torch.cat((idx, idx_next), dim=1)
65
- generated_ids.append(idx_next.item())
66
 
67
- answer = decode(generated_ids)
68
- return answer
 
 
 
 
 
 
 
69
 
 
70
  demo = gr.ChatInterface(
71
  fn=medical_chat,
72
- title="Medical Master (Custom PyTorch AI)",
73
- description="نموذج مبني من الصفر للإجابة على الاستفسارات (يعمل بالوضع النصي حالياً).",
74
  )
75
 
76
  if __name__ == "__main__":
 
5
  from huggingface_hub import hf_hub_download
6
  from model import MedicalMasterAI
7
 
8
+ # إعداد الجهاز
9
+ device = torch.device("cpu") # المساحات المجانية تستخدم المعالج
10
 
11
+ # 1. تحميل التوكنايزر
12
  with open("tokenizer_config.json", "r", encoding="utf-8") as f:
13
  vocab = json.load(f)
 
14
  stoi = vocab["stoi"]
15
  itos = vocab["itos"]
16
 
17
  def encode(text):
18
+ return [stoi.get(c, 0) for c in text] # 0 للمسافات أو الرموز غير المعروفة
19
 
20
  def decode(ids):
21
  return "".join([itos.get(str(i), "") for i in ids])
22
 
23
+ # 2. تحميل النموذج (مرة واحدة فقط)
24
  try:
25
  model = MedicalMasterAI(vocab_size=115, n_layer=48, n_head=8, n_embd=768)
 
 
26
  model_path = hf_hub_download(repo_id="gijl/Medical-Master-1.5B", filename="pytorch_model.bin")
 
 
27
  state_dict = torch.load(model_path, map_location=device, weights_only=True)
 
 
28
  model.load_state_dict(state_dict, strict=False)
29
+ model.eval()
 
 
30
  model_loaded = True
 
31
  except Exception as e:
32
+ print(f"خطأ في تحميل النموذج: {e}")
33
  model_loaded = False
34
 
35
+ # 3. دالة التوليد بنظام Streaming
36
  def medical_chat(message, history):
37
  if not model_loaded:
38
+ yield "النموذج لم يتم تحميله بشكل صحيح."
39
+ return
40
+
41
+ # بناء البرومبت
42
+ prompt = f"Question: {message} Answer:"
43
  idx = torch.tensor([encode(prompt)], dtype=torch.long).to(device)
44
 
45
+ generated_text = ""
 
46
 
47
+ # استخدام inference_mode لتسريع المعالج
48
+ with torch.inference_mode():
49
+ for _ in range(150): # تقليل العدد لسرعة الاستجابة
50
+ # القص ليتناسب مع حجم الـ Position Embedding (256)
51
  idx_cond = idx[:, -256:]
52
 
53
  logits = model(idx_cond)
54
+ logits = logits[:, -1, :] / 0.8 # درجة الحرارة
55
+
 
56
  probs = F.softmax(logits, dim=-1)
57
  idx_next = torch.multinomial(probs, num_samples=1)
58
 
59
+ # إضافة الحرف الجديد
60
  idx = torch.cat((idx, idx_next), dim=1)
 
61
 
62
+ char = decode([idx_next.item()])
63
+ generated_text += char
64
+
65
+ # إرسال النص المنتج حتى الآن للواجهة (Streaming)
66
+ yield generated_text
67
+
68
+ # توقف إذا أنتج النموذج علامة توقف (مثل النقطة إذا رغبت)
69
+ if idx_next.item() == stoi.get(".", -1):
70
+ break
71
 
72
+ # 4. واجهة Gradio
73
  demo = gr.ChatInterface(
74
  fn=medical_chat,
75
+ title="Medical Master 1.5B (Streaming Mode)",
76
+ description="إذا تأخر الرد، انتظر قليلاً فالنموذج يولد النص حرفاً بحرف.",
77
  )
78
 
79
  if __name__ == "__main__":