gijl commited on
Commit
5201445
·
verified ·
1 Parent(s): 1defb5b

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +32 -37
app.py CHANGED
@@ -1,79 +1,74 @@
1
  import gradio as gr
2
  import torch
3
  import json
4
- import torch.nn.functional as F
5
- from huggingface_hub import hf_hub_download
6
- from model import MedicalMasterAI
7
 
8
- # إعداد الجهاز
9
- device = torch.device("cpu") # المساحات المجانية تستخدم المعالج
10
 
11
- # 1. تحميل التوكنايزر
12
  with open("tokenizer_config.json", "r", encoding="utf-8") as f:
13
  vocab = json.load(f)
14
  stoi = vocab["stoi"]
15
  itos = vocab["itos"]
16
 
17
- def encode(text):
18
- return [stoi.get(c, 0) for c in text] # 0 للمسافات أو الرموز غير المعروفة
19
 
20
- def decode(ids):
21
- return "".join([itos.get(str(i), "") for i in ids])
22
-
23
- # 2. تحميل النموذج (مرة واحدة فقط)
24
  try:
25
- model = MedicalMasterAI(vocab_size=115, n_layer=48, n_head=8, n_embd=768)
26
- model_path = hf_hub_download(repo_id="gijl/Medical-Master-1.5B", filename="pytorch_model.bin")
27
- state_dict = torch.load(model_path, map_location=device, weights_only=True)
28
- model.load_state_dict(state_dict, strict=False)
 
 
 
 
29
  model.eval()
30
  model_loaded = True
 
31
  except Exception as e:
32
- print(f"خطأ في تحميل النموذج: {e}")
33
  model_loaded = False
34
 
35
- # 3. دالة التوليد بنظام Streaming
36
  def medical_chat(message, history):
37
  if not model_loaded:
38
- yield "النموذج لم يتم تحميله بشكل صحيح."
39
  return
40
 
41
- # بناء البرومبت
42
  prompt = f"Question: {message} Answer:"
43
- idx = torch.tensor([encode(prompt)], dtype=torch.long).to(device)
44
 
45
  generated_text = ""
46
 
47
- # استخدام inference_mode لتسريع المعالج
48
- with torch.inference_mode():
49
- for _ in range(150): # تقليل العدد لسرعة الاستجابة
50
- # القص ليتناسب مع حجم الـ Position Embedding (256)
51
- idx_cond = idx[:, -256:]
52
 
53
- logits = model(idx_cond)
54
- logits = logits[:, -1, :] / 0.8 # درجة الحرارة
 
55
 
56
- probs = F.softmax(logits, dim=-1)
57
  idx_next = torch.multinomial(probs, num_samples=1)
58
 
59
- # إضافة الحرف الجديد
60
- idx = torch.cat((idx, idx_next), dim=1)
61
-
62
  char = decode([idx_next.item()])
63
  generated_text += char
64
 
65
- # إرسال النص المنتج حتى الآن للواجهة (Streaming)
66
  yield generated_text
67
 
68
- # توقف إذا أنتج النموذج علامة توقف (مثل النقطة إذا رغبت)
69
  if idx_next.item() == stoi.get(".", -1):
70
  break
71
 
72
- # 4. واجهة Gradio
73
  demo = gr.ChatInterface(
74
  fn=medical_chat,
75
- title="Medical Master 1.5B (Streaming Mode)",
76
- description="إذا تأخر الرد، انتظر قليلاً فالنموذج يولد النص حرفاً بحرف.",
77
  )
78
 
79
  if __name__ == "__main__":
 
1
  import gradio as gr
2
  import torch
3
  import json
4
+ from transformers import AutoModelForCausalLM
 
 
5
 
6
+ device = torch.device("cpu")
 
7
 
8
+ # 1. تحميل التوكنايزر المخصص الخاص بك (بدون تعديل الملف)
9
  with open("tokenizer_config.json", "r", encoding="utf-8") as f:
10
  vocab = json.load(f)
11
  stoi = vocab["stoi"]
12
  itos = vocab["itos"]
13
 
14
+ def encode(text): return [stoi.get(c, 0) for c in text]
15
+ def decode(ids): return "".join([itos.get(str(i), "") for i in ids])
16
 
17
+ # 2. تحميل النموذج باستخدام مكتبة Transformers مباشرة
 
 
 
18
  try:
19
+ print("جاري تحميل النموذج من Hugging Face...")
20
+ # هذا السطر سيقرأ config.json و pytorch_model.bin بشكل صحيح ومطابق 100%
21
+ model = AutoModelForCausalLM.from_pretrained(
22
+ "gijl/Medical-Master-1.5B",
23
+ torch_dtype=torch.float32,
24
+ low_cpu_mem_usage=True
25
+ )
26
+ model.to(device)
27
  model.eval()
28
  model_loaded = True
29
+ print("تم التحميل بنجاح وتم مطابقة الأوزان!")
30
  except Exception as e:
31
+ print(f"Error: {e}")
32
  model_loaded = False
33
 
34
+ # 3. دالة المحادثة (Streaming)
35
  def medical_chat(message, history):
36
  if not model_loaded:
37
+ yield "حدث خطأ في تحميل النموذج."
38
  return
39
 
 
40
  prompt = f"Question: {message} Answer:"
41
+ input_ids = torch.tensor([encode(prompt)], dtype=torch.long).to(device)
42
 
43
  generated_text = ""
44
 
45
+ with torch.no_grad():
46
+ for _ in range(150): # توليد 150 حرف
47
+ # الحد الأقصى للسياق هو 256 كما هو في config.json
48
+ idx_cond = input_ids[:, -256:]
 
49
 
50
+ # تمرير البيانات لنموذج HF
51
+ outputs = model(input_ids=idx_cond)
52
+ logits = outputs.logits[:, -1, :] / 0.7 # حرارة 0.7 لتقليل العشوائية
53
 
54
+ probs = torch.nn.functional.softmax(logits, dim=-1)
55
  idx_next = torch.multinomial(probs, num_samples=1)
56
 
57
+ input_ids = torch.cat((input_ids, idx_next), dim=1)
 
 
58
  char = decode([idx_next.item()])
59
  generated_text += char
60
 
 
61
  yield generated_text
62
 
63
+ # التوقف إذا وصل لنقطة
64
  if idx_next.item() == stoi.get(".", -1):
65
  break
66
 
67
+ # 4. بناء الواجهة
68
  demo = gr.ChatInterface(
69
  fn=medical_chat,
70
+ title="Medical Master 1.5B",
71
+ description="مساعد طبي ذكي يعمل بحروف اللغة العربية والإنجليزية.",
72
  )
73
 
74
  if __name__ == "__main__":