HK2184 committed on
Commit
b0f0e1c
·
verified ·
1 Parent(s): a9ba36d

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +56 -38
app.py CHANGED
@@ -1,15 +1,32 @@
1
  import os
2
-
3
  import gradio as gr
4
-
5
- os.environ["ROCR_VISIBLE_DEVICES"] = "0"
6
- os.environ["HIP_VISIBLE_DEVICES"] = "0"
7
- os.environ["HSA_OVERRIDE_GFX_VERSION"] = "9.4.2"
8
-
9
- BASE_MODEL = "Qwen/Qwen2-1.5B"
10
- ADAPTER_PATH = "./outputs"
11
-
12
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
13
 
14
  EXAMPLES = [
15
  ["Which artery is occluded in inferior MI with ST elevation in II, III, aVF?",
@@ -29,20 +46,29 @@ EXAMPLES = [
29
  def answer(question, opa, opb, opc, opd):
30
  if not question.strip():
31
  return "Please enter a question."
32
-
33
- # simple mock logic (random-ish but believable)
34
- import random
35
- options = [opa, opb, opc, opd]
36
- letters = ["A", "B", "C", "D"]
37
-
38
- idx = random.randint(0, 3)
39
-
40
- return f"""Answer: {letters[idx]}) {options[idx]}
41
-
42
- Explanation:
43
- This is a mock demo running without the full model.
44
- In the real system, a fine-tuned medical LLM analyzes the clinical context
45
- and selects the most appropriate answer based on learned patterns."""
 
 
 
 
 
 
 
 
 
46
 
47
  CSS = """
48
  @import url('https://fonts.googleapis.com/css2?family=Syne:wght@400;600;700;800&family=DM+Sans:wght@300;400;500&display=swap');
@@ -72,7 +98,6 @@ body, .gradio-container {
72
  padding: 0 20px 60px !important;
73
  }
74
 
75
- /* Header */
76
  #header {
77
  padding: 44px 0 28px;
78
  border-bottom: 1px solid var(--border);
@@ -105,7 +130,6 @@ h1#title {
105
  h1#title em { color: var(--accent); font-style: normal; }
106
  .subtitle { font-size: 14px; color: var(--muted); font-weight: 300; line-height: 1.6; max-width: 520px; }
107
 
108
- /* Stats */
109
  #stats {
110
  display: flex; border: 1px solid var(--border);
111
  border-radius: 12px; overflow: hidden;
@@ -118,7 +142,6 @@ h1#title em { color: var(--accent); font-style: normal; }
118
  .dot { display: inline-block; width: 6px; height: 6px; border-radius: 50%; background: var(--green); margin-right: 4px; animation: blink 2s infinite; }
119
  @keyframes blink { 0%,100%{opacity:1} 50%{opacity:0.3} }
120
 
121
- /* Inputs */
122
  label span, .label-wrap span {
123
  font-family: 'DM Sans', sans-serif !important;
124
  font-size: 11px !important; font-weight: 500 !important;
@@ -140,7 +163,6 @@ textarea:focus, input[type=text]:focus {
140
  outline: none !important;
141
  }
142
 
143
- /* Section labels */
144
  .section-label {
145
  font-size: 10px; font-weight: 600;
146
  letter-spacing: 0.12em; text-transform: uppercase;
@@ -152,7 +174,6 @@ textarea:focus, input[type=text]:focus {
152
  background: var(--accent); display: inline-block;
153
  }
154
 
155
- /* Button */
156
  button.lg.primary {
157
  background: linear-gradient(135deg, var(--accent2), var(--accent)) !important;
158
  border: none !important; border-radius: 10px !important;
@@ -165,7 +186,6 @@ button.lg.primary {
165
  }
166
  button.lg.primary:hover { opacity: 0.85 !important; transform: translateY(-1px) !important; }
167
 
168
- /* Output */
169
  .out-box textarea {
170
  background: var(--surface2) !important;
171
  border: 1px solid var(--border) !important;
@@ -174,7 +194,6 @@ button.lg.primary:hover { opacity: 0.85 !important; transform: translateY(-1px)
174
  color: var(--text) !important; min-height: 280px !important;
175
  }
176
 
177
- /* Examples */
178
  .examples-holder table {
179
  background: var(--surface) !important;
180
  border: 1px solid var(--border) !important;
@@ -187,7 +206,6 @@ button.lg.primary:hover { opacity: 0.85 !important; transform: translateY(-1px)
187
  }
188
  .examples-holder tr:hover td { background: var(--surface2) !important; cursor: pointer; }
189
 
190
- /* Footer */
191
  #footer {
192
  margin-top: 44px; padding-top: 22px;
193
  border-top: 1px solid var(--border);
@@ -206,7 +224,7 @@ with gr.Blocks(css=CSS, title="MedQA — AMD ROCm") as demo:
206
  <div id="header">
207
  <div class="badges">
208
  <span class="badge b-amd">AMD MI300X</span>
209
- <span class="badge b-rocm">ROCm 6.1</span>
210
  <span class="badge b-lora">LoRA Fine-tuned</span>
211
  <span class="badge b-live"><span class="dot"></span>Live Inference</span>
212
  </div>
@@ -217,7 +235,7 @@ with gr.Blocks(css=CSS, title="MedQA — AMD ROCm") as demo:
217
  </p>
218
  </div>
219
  <div id="stats">
220
- <div class="stat"><span class="sv">1.5B</span><span class="sl">Parameters</span></div>
221
  <div class="stat"><span class="sv">LoRA</span><span class="sl">Fine-tuning</span></div>
222
  <div class="stat"><span class="sv">193k</span><span class="sl">Training QA</span></div>
223
  <div class="stat"><span class="sv">MI300X</span><span class="sl">AMD GPU</span></div>
@@ -256,17 +274,17 @@ with gr.Blocks(css=CSS, title="MedQA — AMD ROCm") as demo:
256
  examples=EXAMPLES,
257
  inputs=[question, opa, opb, opc, opd],
258
  label="",
259
- )
260
 
261
  gr.HTML("""
262
  <div id="footer">
263
  <div class="fl">
264
  Built on <strong>AMD Developer Cloud</strong> &nbsp;·&nbsp;
265
- Model: <strong>Qwen2-1.5B + LoRA</strong> &nbsp;·&nbsp;
266
  Dataset: <strong>MedMCQA</strong>
267
  </div>
268
  <div class="fr">
269
- <a class="flink" href="https://github.com" target="_blank">GitHub →</a>
270
  <a class="flink" href="https://lablab.ai" target="_blank">lablab.ai →</a>
271
  <a class="flink" href="https://cloud.amd.com" target="_blank">AMD Cloud →</a>
272
  </div>
@@ -276,4 +294,4 @@ with gr.Blocks(css=CSS, title="MedQA — AMD ROCm") as demo:
276
  btn.click(fn=answer, inputs=[question, opa, opb, opc, opd], outputs=output)
277
 
278
  if __name__ == "__main__":
279
- demo.launch(server_name="0.0.0.0", server_port=7860, share=True)
 
1
  import os
2
+ import torch
3
  import gradio as gr
4
+ from transformers import AutoTokenizer, AutoModelForCausalLM
5
+ from peft import PeftModel
6
+
7
+ # ← CHANGE 1: ROCm env vars removed
8
+
9
+ BASE_MODEL = "Qwen/Qwen3-1.7B"
10
+ ADAPTER_PATH = "HK2184/medqa-qwen3-lora" # ← CHANGE 2: HF Hub instead of ./outputs
11
+
12
+ print("Loading tokenizer...")
13
+ tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL, trust_remote_code=True)
14
+ tokenizer.pad_token = tokenizer.eos_token
15
+ tokenizer.padding_side = "left"
16
+
17
+ print("Loading model...")
18
+ DTYPE = torch.bfloat16 if torch.cuda.is_available() else torch.float32 # ← CHANGE 3: auto dtype
19
+ base = AutoModelForCausalLM.from_pretrained(
20
+ BASE_MODEL,
21
+ dtype=DTYPE,
22
+ device_map="auto",
23
+ trust_remote_code=True,
24
+ low_cpu_mem_usage=True,
25
+ )
26
+ model = PeftModel.from_pretrained(base, ADAPTER_PATH)
27
+ model = model.merge_and_unload()
28
+ model.eval()
29
+ print("Ready!")
30
 
31
  EXAMPLES = [
32
  ["Which artery is occluded in inferior MI with ST elevation in II, III, aVF?",
 
46
  def answer(question, opa, opb, opc, opd):
47
  if not question.strip():
48
  return "Please enter a question."
49
+ if not all([opa.strip(), opb.strip(), opc.strip(), opd.strip()]):
50
+ return "Please fill in all four options."
51
+ prompt = (
52
+ f"### Question:\n{question}\n\n"
53
+ f"### Options:\n"
54
+ f"A) {opa}\nB) {opb}\nC) {opc}\nD) {opd}\n\n"
55
+ f"### Answer:\n"
56
+ )
57
+ inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
58
+ with torch.no_grad():
59
+ out = model.generate(
60
+ **inputs,
61
+ max_new_tokens=200,
62
+ do_sample=True,
63
+ temperature=0.7,
64
+ top_p=0.9,
65
+ top_k=50,
66
+ repetition_penalty=1.3,
67
+ eos_token_id=tokenizer.eos_token_id,
68
+ pad_token_id=tokenizer.eos_token_id,
69
+ )
70
+ new = out[0][inputs["input_ids"].shape[-1]:]
71
+ return tokenizer.decode(new, skip_special_tokens=True)
72
 
73
  CSS = """
74
  @import url('https://fonts.googleapis.com/css2?family=Syne:wght@400;600;700;800&family=DM+Sans:wght@300;400;500&display=swap');
 
98
  padding: 0 20px 60px !important;
99
  }
100
 
 
101
  #header {
102
  padding: 44px 0 28px;
103
  border-bottom: 1px solid var(--border);
 
130
  h1#title em { color: var(--accent); font-style: normal; }
131
  .subtitle { font-size: 14px; color: var(--muted); font-weight: 300; line-height: 1.6; max-width: 520px; }
132
 
 
133
  #stats {
134
  display: flex; border: 1px solid var(--border);
135
  border-radius: 12px; overflow: hidden;
 
142
  .dot { display: inline-block; width: 6px; height: 6px; border-radius: 50%; background: var(--green); margin-right: 4px; animation: blink 2s infinite; }
143
  @keyframes blink { 0%,100%{opacity:1} 50%{opacity:0.3} }
144
 
 
145
  label span, .label-wrap span {
146
  font-family: 'DM Sans', sans-serif !important;
147
  font-size: 11px !important; font-weight: 500 !important;
 
163
  outline: none !important;
164
  }
165
 
 
166
  .section-label {
167
  font-size: 10px; font-weight: 600;
168
  letter-spacing: 0.12em; text-transform: uppercase;
 
174
  background: var(--accent); display: inline-block;
175
  }
176
 
 
177
  button.lg.primary {
178
  background: linear-gradient(135deg, var(--accent2), var(--accent)) !important;
179
  border: none !important; border-radius: 10px !important;
 
186
  }
187
  button.lg.primary:hover { opacity: 0.85 !important; transform: translateY(-1px) !important; }
188
 
 
189
  .out-box textarea {
190
  background: var(--surface2) !important;
191
  border: 1px solid var(--border) !important;
 
194
  color: var(--text) !important; min-height: 280px !important;
195
  }
196
 
 
197
  .examples-holder table {
198
  background: var(--surface) !important;
199
  border: 1px solid var(--border) !important;
 
206
  }
207
  .examples-holder tr:hover td { background: var(--surface2) !important; cursor: pointer; }
208
 
 
209
  #footer {
210
  margin-top: 44px; padding-top: 22px;
211
  border-top: 1px solid var(--border);
 
224
  <div id="header">
225
  <div class="badges">
226
  <span class="badge b-amd">AMD MI300X</span>
227
+ <span class="badge b-rocm">ROCm 7.2</span>
228
  <span class="badge b-lora">LoRA Fine-tuned</span>
229
  <span class="badge b-live"><span class="dot"></span>Live Inference</span>
230
  </div>
 
235
  </p>
236
  </div>
237
  <div id="stats">
238
+ <div class="stat"><span class="sv">1.7B</span><span class="sl">Parameters</span></div>
239
  <div class="stat"><span class="sv">LoRA</span><span class="sl">Fine-tuning</span></div>
240
  <div class="stat"><span class="sv">193k</span><span class="sl">Training QA</span></div>
241
  <div class="stat"><span class="sv">MI300X</span><span class="sl">AMD GPU</span></div>
 
274
  examples=EXAMPLES,
275
  inputs=[question, opa, opb, opc, opd],
276
  label="",
277
+ )
278
 
279
  gr.HTML("""
280
  <div id="footer">
281
  <div class="fl">
282
  Built on <strong>AMD Developer Cloud</strong> &nbsp;·&nbsp;
283
+ Model: <strong>Qwen3-1.7B + LoRA</strong> &nbsp;·&nbsp;
284
  Dataset: <strong>MedMCQA</strong>
285
  </div>
286
  <div class="fr">
287
+ <a class="flink" href="https://github.com/HK2184/MedQA-Medical-AI-on-AMD-ROCm" target="_blank">GitHub →</a>
288
  <a class="flink" href="https://lablab.ai" target="_blank">lablab.ai →</a>
289
  <a class="flink" href="https://cloud.amd.com" target="_blank">AMD Cloud →</a>
290
  </div>
 
294
  btn.click(fn=answer, inputs=[question, opa, opb, opc, opd], outputs=output)
295
 
296
  if __name__ == "__main__":
297
+ demo.launch() # ← CHANGE 4: no server_name/port/share for HF Spaces