gijl commited on
Commit
a5f0014
·
verified ·
1 Parent(s): cde5e54

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +15 -9
app.py CHANGED
@@ -2,12 +2,14 @@ import gradio as gr
2
  import torch
3
  import json
4
  import torch.nn.functional as F
 
5
  from model import MedicalMasterAI # استيراد المعمارية الخاصة بك
6
 
7
  # تحديد الجهاز
8
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
9
 
10
  # 1. تحميل التوكنايزر المخصص الخاص بك
 
11
  with open("tokenizer_config.json", "r", encoding="utf-8") as f:
12
  vocab = json.load(f)
13
 
@@ -16,7 +18,6 @@ itos = vocab["itos"]
16
 
17
  # دوال التشفير وفك التشفير
18
  def encode(text):
19
- # إذا لم يجد الحرف، يستبدله بمسافة (0)
20
  return [stoi.get(c, 0) for c in text]
21
 
22
  def decode(ids):
@@ -27,14 +28,20 @@ try:
27
  # تهيئة النموذج بنفس الأبعاد الموجودة في كودك
28
  model = MedicalMasterAI(vocab_size=115, n_layer=48, n_head=8, n_embd=768)
29
 
30
- # تحميل الأوزان (تأكد من أن اسم الملف في المساحة هو pytorch_model.bin)
31
- # استخدام weights_only=True لدواعي أمنية ولمنع التحذيرات
32
- state_dict = torch.load("pytorch_model.bin", map_location=device, weights_only=True)
 
 
 
 
 
33
  model.load_state_dict(state_dict)
34
 
35
  model.to(device)
36
  model.eval() # وضع التقييم
37
  model_loaded = True
 
38
  except Exception as e:
39
  print(f"Error loading model: {e}")
40
  model_loaded = False
@@ -42,22 +49,21 @@ except Exception as e:
42
  # 3. دالة المحادثة (التوليد التلقائي)
43
  def medical_chat(message, history):
44
  if not model_loaded:
45
- return "حدث خطأ أثناء تحميل أوزان النموذج. تأكد من وجود ملف pytorch_model.bin"
46
 
47
- # ملاحظة: تم إزالة \n لأنها غير موجودة في قاموسك (stoi)
48
  prompt = f"Question: {message} Answer:"
49
 
50
  # تحويل النص إلى أرقام (Tensors)
51
  idx = torch.tensor([encode(prompt)], dtype=torch.long).to(device)
52
 
53
- # عدد الأحرف التي سيولدها النموذج (بما أنه Character-level)
54
  max_new_chars = 200
55
 
56
  generated_ids = []
57
 
58
  with torch.no_grad():
59
  for _ in range(max_new_chars):
60
- # قص السياق إذا تجاوز 1024 (الحد الأقصى للـ Position Embedding في كودك)
61
  idx_cond = idx[:, -1024:]
62
 
63
  # تمرير البيانات للنموذج
@@ -66,7 +72,7 @@ def medical_chat(message, history):
66
  # التركيز على الحرف الأخير فقط
67
  logits = logits[:, -1, :]
68
 
69
- # تطبيق الحرارة (Temperature) لتنويع الإجابات
70
  temperature = 0.8
71
  logits = logits / temperature
72
 
 
2
  import torch
3
  import json
4
  import torch.nn.functional as F
5
+ from huggingface_hub import hf_hub_download # استيراد أداة التحميل
6
  from model import MedicalMasterAI # استيراد المعمارية الخاصة بك
7
 
8
  # تحديد الجهاز
9
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
10
 
11
  # 1. تحميل التوكنايزر المخصص الخاص بك
12
+ # نفترض أن هذا الملف موجود محلياً في المساحة
13
  with open("tokenizer_config.json", "r", encoding="utf-8") as f:
14
  vocab = json.load(f)
15
 
 
18
 
19
  # دوال التشفير وفك التشفير
20
  def encode(text):
 
21
  return [stoi.get(c, 0) for c in text]
22
 
23
  def decode(ids):
 
28
  # تهيئة النموذج بنفس الأبعاد الموجودة في كودك
29
  model = MedicalMasterAI(vocab_size=115, n_layer=48, n_head=8, n_embd=768)
30
 
31
+ print("جاري سحب ملف الأوزان من المستودع... قد يستغرق هذا بعض الوقت.")
32
+ # سحب الملف الضخم من مستودعك مباشرة
33
+ # تأكد أن "gijl/Medical-Master-1.5B" هو المسار الصحيح لمستودعك
34
+ model_path = hf_hub_download(repo_id="gijl/Medical-Master-1.5B", filename="pytorch_model.bin")
35
+ print("تم التحميل بنجاح. جاري قراءة الأوزان...")
36
+
37
+ # تحميل الأوزان باستخدام المسار الذي أرجعته دالة التحميل
38
+ state_dict = torch.load(model_path, map_location=device, weights_only=True)
39
  model.load_state_dict(state_dict)
40
 
41
  model.to(device)
42
  model.eval() # وضع التقييم
43
  model_loaded = True
44
+ print("النموذج جاهز للعمل!")
45
  except Exception as e:
46
  print(f"Error loading model: {e}")
47
  model_loaded = False
 
49
  # 3. دالة المحادثة (التوليد التلقائي)
50
  def medical_chat(message, history):
51
  if not model_loaded:
52
+ return "حدث خطأ أثناء تحميل أوزان النموذج. يرجى مراجعة سجلات الأخطاء (Logs)."
53
 
 
54
  prompt = f"Question: {message} Answer:"
55
 
56
  # تحويل النص إلى أرقام (Tensors)
57
  idx = torch.tensor([encode(prompt)], dtype=torch.long).to(device)
58
 
59
+ # عدد الأحرف التي سيولدها النموذج
60
  max_new_chars = 200
61
 
62
  generated_ids = []
63
 
64
  with torch.no_grad():
65
  for _ in range(max_new_chars):
66
+ # قص السياق إذا تجاوز 1024
67
  idx_cond = idx[:, -1024:]
68
 
69
  # تمرير البيانات للنموذج
 
72
  # التركيز على الحرف الأخير فقط
73
  logits = logits[:, -1, :]
74
 
75
+ # تطبيق الحرارة (Temperature)
76
  temperature = 0.8
77
  logits = logits / temperature
78