HK2184 commited on
Commit
d6f9870
·
verified ·
1 Parent(s): 8276d5f
Files changed (1) hide show
  1. app.py +305 -0
app.py ADDED
@@ -0,0 +1,305 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import torch
3
+ import gradio as gr
4
+ from transformers import AutoTokenizer, AutoModelForCausalLM
5
+ from peft import PeftModel
6
+
7
+ os.environ["ROCR_VISIBLE_DEVICES"] = "0"
8
+ os.environ["HIP_VISIBLE_DEVICES"] = "0"
9
+ os.environ["HSA_OVERRIDE_GFX_VERSION"] = "9.4.2"
10
+
11
+ BASE_MODEL = "Qwen/Qwen2-1.5B"
12
+ ADAPTER_PATH = "./outputs"
13
+
14
+ print("Loading tokenizer...")
15
+ tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL, trust_remote_code=True)
16
+ tokenizer.pad_token = tokenizer.eos_token
17
+ tokenizer.padding_side = "left"
18
+
19
+ print("Loading model...")
20
+ base = AutoModelForCausalLM.from_pretrained(
21
+ BASE_MODEL,
22
+ dtype=torch.bfloat16,
23
+ device_map="auto",
24
+ trust_remote_code=True,
25
+ )
26
+ model = PeftModel.from_pretrained(base, ADAPTER_PATH)
27
+ model = model.merge_and_unload()
28
+ model.eval()
29
+ print("Ready!")
30
+
31
+ EXAMPLES = [
32
+ ["Which artery is occluded in inferior MI with ST elevation in II, III, aVF?",
33
+ "Left anterior descending artery", "Right coronary artery",
34
+ "Left circumflex artery", "Left main coronary artery"],
35
+ ["First-line treatment for hypertensive emergency?",
36
+ "Oral amlodipine", "IV labetalol or IV nitroprusside",
37
+ "Sublingual nifedipine", "IM hydralazine"],
38
+ ["Most common cause of community-acquired pneumonia?",
39
+ "Klebsiella pneumoniae", "Streptococcus pneumoniae",
40
+ "Haemophilus influenzae", "Mycoplasma pneumoniae"],
41
+ ["Drug of choice for absence seizures?",
42
+ "Phenytoin", "Carbamazepine",
43
+ "Ethosuximide", "Valproate"],
44
+ ]
45
+
46
+ def answer(question, opa, opb, opc, opd):
47
+ if not question.strip():
48
+ return "Please enter a question."
49
+ if not all([opa.strip(), opb.strip(), opc.strip(), opd.strip()]):
50
+ return "Please fill in all four options."
51
+ prompt = (
52
+ f"### Question:\n{question}\n\n"
53
+ f"### Options:\n"
54
+ f"A) {opa}\nB) {opb}\nC) {opc}\nD) {opd}\n\n"
55
+ f"### Answer:\n"
56
+ )
57
+ inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
58
+ with torch.no_grad():
59
+ out = model.generate(
60
+ **inputs,
61
+ max_new_tokens=200,
62
+ do_sample=True,
63
+ temperature=0.7,
64
+ top_p=0.9,
65
+ top_k=50,
66
+ repetition_penalty=1.3,
67
+ eos_token_id=tokenizer.eos_token_id,
68
+ pad_token_id=tokenizer.eos_token_id,
69
+ )
70
+ new = out[0][inputs["input_ids"].shape[-1]:]
71
+ return tokenizer.decode(new, skip_special_tokens=True)
72
+
73
+ CSS = """
74
+ @import url('https://fonts.googleapis.com/css2?family=Syne:wght@400;600;700;800&family=DM+Sans:wght@300;400;500&display=swap');
75
+
76
+ :root {
77
+ --bg: #080d1a;
78
+ --surface: #0f1624;
79
+ --surface2: #162030;
80
+ --border: #1a3356;
81
+ --accent: #00c8f0;
82
+ --accent2: #0055ff;
83
+ --green: #00f0a0;
84
+ --text: #deeeff;
85
+ --muted: #4a6080;
86
+ --danger: #ff3366;
87
+ }
88
+
89
+ body, .gradio-container {
90
+ background: var(--bg) !important;
91
+ font-family: 'DM Sans', sans-serif !important;
92
+ color: var(--text) !important;
93
+ }
94
+
95
+ .gradio-container {
96
+ max-width: 1080px !important;
97
+ margin: 0 auto !important;
98
+ padding: 0 20px 60px !important;
99
+ }
100
+
101
+ /* Header */
102
+ #header {
103
+ padding: 44px 0 28px;
104
+ border-bottom: 1px solid var(--border);
105
+ margin-bottom: 32px;
106
+ position: relative;
107
+ }
108
+ #header::after {
109
+ content: '';
110
+ position: absolute;
111
+ bottom: -1px; left: 0; right: 0; height: 1px;
112
+ background: linear-gradient(90deg, var(--accent2), var(--accent), var(--green));
113
+ }
114
+ .badges { display: flex; gap: 8px; margin-bottom: 14px; flex-wrap: wrap; }
115
+ .badge {
116
+ font-size: 10px; font-weight: 600;
117
+ letter-spacing: 0.1em; text-transform: uppercase;
118
+ padding: 3px 9px; border-radius: 4px; border: 1px solid;
119
+ }
120
+ .b-amd { color: #ff6030; border-color: #ff603030; background: #ff603010; }
121
+ .b-rocm { color: var(--accent); border-color: #00c8f030; background: #00c8f008; }
122
+ .b-lora { color: var(--green); border-color: #00f0a030; background: #00f0a008; }
123
+ .b-live { color: #ffcc00; border-color: #ffcc0030; background: #ffcc0008; }
124
+
125
+ h1#title {
126
+ font-family: 'Syne', sans-serif !important;
127
+ font-size: 42px !important; font-weight: 800 !important;
128
+ letter-spacing: -0.03em !important; line-height: 1 !important;
129
+ color: var(--text) !important; margin-bottom: 10px !important;
130
+ }
131
+ h1#title em { color: var(--accent); font-style: normal; }
132
+ .subtitle { font-size: 14px; color: var(--muted); font-weight: 300; line-height: 1.6; max-width: 520px; }
133
+
134
+ /* Stats */
135
+ #stats {
136
+ display: flex; border: 1px solid var(--border);
137
+ border-radius: 12px; overflow: hidden;
138
+ background: var(--surface); margin-bottom: 28px;
139
+ }
140
+ .stat { flex: 1; padding: 14px 16px; text-align: center; border-right: 1px solid var(--border); }
141
+ .stat:last-child { border-right: none; }
142
+ .sv { font-family: 'Syne', sans-serif; font-size: 20px; font-weight: 700; color: var(--accent); display: block; }
143
+ .sl { font-size: 10px; color: var(--muted); text-transform: uppercase; letter-spacing: 0.08em; }
144
+ .dot { display: inline-block; width: 6px; height: 6px; border-radius: 50%; background: var(--green); margin-right: 4px; animation: blink 2s infinite; }
145
+ @keyframes blink { 0%,100%{opacity:1} 50%{opacity:0.3} }
146
+
147
+ /* Inputs */
148
+ label span, .label-wrap span {
149
+ font-family: 'DM Sans', sans-serif !important;
150
+ font-size: 11px !important; font-weight: 500 !important;
151
+ color: var(--muted) !important; text-transform: uppercase !important;
152
+ letter-spacing: 0.07em !important;
153
+ }
154
+ textarea, input[type=text] {
155
+ background: var(--surface2) !important;
156
+ border: 1px solid var(--border) !important;
157
+ border-radius: 10px !important;
158
+ color: var(--text) !important;
159
+ font-family: 'DM Sans', sans-serif !important;
160
+ font-size: 14px !important; line-height: 1.6 !important;
161
+ transition: border-color 0.2s, box-shadow 0.2s !important;
162
+ }
163
+ textarea:focus, input[type=text]:focus {
164
+ border-color: var(--accent) !important;
165
+ box-shadow: 0 0 0 3px #00c8f012 !important;
166
+ outline: none !important;
167
+ }
168
+
169
+ /* Section labels */
170
+ .section-label {
171
+ font-size: 10px; font-weight: 600;
172
+ letter-spacing: 0.12em; text-transform: uppercase;
173
+ color: var(--muted); margin-bottom: 10px;
174
+ display: flex; align-items: center; gap: 7px;
175
+ }
176
+ .section-label::before {
177
+ content: ''; width: 5px; height: 5px; border-radius: 50%;
178
+ background: var(--accent); display: inline-block;
179
+ }
180
+
181
+ /* Button */
182
+ button.lg.primary {
183
+ background: linear-gradient(135deg, var(--accent2), var(--accent)) !important;
184
+ border: none !important; border-radius: 10px !important;
185
+ color: #fff !important; font-family: 'Syne', sans-serif !important;
186
+ font-size: 14px !important; font-weight: 700 !important;
187
+ letter-spacing: 0.04em !important; padding: 14px !important;
188
+ width: 100% !important; margin-top: 14px !important;
189
+ cursor: pointer !important;
190
+ transition: opacity 0.2s, transform 0.15s !important;
191
+ }
192
+ button.lg.primary:hover { opacity: 0.85 !important; transform: translateY(-1px) !important; }
193
+
194
+ /* Output */
195
+ .out-box textarea {
196
+ background: var(--surface2) !important;
197
+ border: 1px solid var(--border) !important;
198
+ border-radius: 10px !important;
199
+ font-size: 14px !important; line-height: 1.8 !important;
200
+ color: var(--text) !important; min-height: 280px !important;
201
+ }
202
+
203
+ /* Examples */
204
+ .examples-holder table {
205
+ background: var(--surface) !important;
206
+ border: 1px solid var(--border) !important;
207
+ border-radius: 10px !important; overflow: hidden !important;
208
+ }
209
+ .examples-holder td, .examples-holder th {
210
+ background: transparent !important; color: var(--text) !important;
211
+ font-size: 13px !important; border-color: var(--border) !important;
212
+ font-family: 'DM Sans', sans-serif !important;
213
+ }
214
+ .examples-holder tr:hover td { background: var(--surface2) !important; cursor: pointer; }
215
+
216
+ /* Footer */
217
+ #footer {
218
+ margin-top: 44px; padding-top: 22px;
219
+ border-top: 1px solid var(--border);
220
+ display: flex; justify-content: space-between;
221
+ align-items: center; flex-wrap: wrap; gap: 10px;
222
+ }
223
+ .fl { font-size: 12px; color: var(--muted); }
224
+ .fl strong { color: var(--text); }
225
+ .fr { display: flex; gap: 14px; }
226
+ .flink { font-size: 12px; color: var(--accent); text-decoration: none; }
227
+ """
228
+
229
+ with gr.Blocks(css=CSS, title="MedQA — AMD ROCm") as demo:
230
+
231
+ gr.HTML("""
232
+ <div id="header">
233
+ <div class="badges">
234
+ <span class="badge b-amd">AMD MI300X</span>
235
+ <span class="badge b-rocm">ROCm 6.1</span>
236
+ <span class="badge b-lora">LoRA Fine-tuned</span>
237
+ <span class="badge b-live"><span class="dot"></span>Live Inference</span>
238
+ </div>
239
+ <h1 id="title">Med<em>QA</em> Assistant</h1>
240
+ <p class="subtitle">
241
+ Clinical question-answering AI fine-tuned on MedMCQA.
242
+ Running on AMD Instinct MI300X via ROCm — no CUDA required.
243
+ </p>
244
+ </div>
245
+ <div id="stats">
246
+ <div class="stat"><span class="sv">1.5B</span><span class="sl">Parameters</span></div>
247
+ <div class="stat"><span class="sv">LoRA</span><span class="sl">Fine-tuning</span></div>
248
+ <div class="stat"><span class="sv">193k</span><span class="sl">Training QA</span></div>
249
+ <div class="stat"><span class="sv">MI300X</span><span class="sl">AMD GPU</span></div>
250
+ <div class="stat"><span class="sv">bf16</span><span class="sl">Precision</span></div>
251
+ </div>
252
+ """)
253
+
254
+ with gr.Row():
255
+ with gr.Column(scale=1):
256
+ gr.HTML('<div class="section-label">Clinical Question</div>')
257
+ question = gr.Textbox(
258
+ label="",
259
+ placeholder="e.g. A 45-year-old presents with sudden onset severe headache...",
260
+ lines=4,
261
+ )
262
+ gr.HTML('<div class="section-label" style="margin-top:14px">Answer Options</div>')
263
+ with gr.Row():
264
+ opa = gr.Textbox(label="Option A", placeholder="First option")
265
+ opb = gr.Textbox(label="Option B", placeholder="Second option")
266
+ with gr.Row():
267
+ opc = gr.Textbox(label="Option C", placeholder="Third option")
268
+ opd = gr.Textbox(label="Option D", placeholder="Fourth option")
269
+ btn = gr.Button("Analyze Question", variant="primary")
270
+
271
+ with gr.Column(scale=1):
272
+ gr.HTML('<div class="section-label">AI Answer & Reasoning</div>')
273
+ output = gr.Textbox(
274
+ label="",
275
+ placeholder="Answer and clinical explanation will appear here...",
276
+ lines=14,
277
+ elem_classes=["out-box"],
278
+ )
279
+
280
+ gr.HTML('<div class="section-label" style="margin-top:24px">Sample Questions — click any to load</div>')
281
+ gr.Examples(
282
+ examples=EXAMPLES,
283
+ inputs=[question, opa, opb, opc, opd],
284
+ label="",
285
+ )
286
+
287
+ gr.HTML("""
288
+ <div id="footer">
289
+ <div class="fl">
290
+ Built on <strong>AMD Developer Cloud</strong> &nbsp;·&nbsp;
291
+ Model: <strong>Qwen2-1.5B + LoRA</strong> &nbsp;·&nbsp;
292
+ Dataset: <strong>MedMCQA</strong>
293
+ </div>
294
+ <div class="fr">
295
+ <a class="flink" href="https://github.com" target="_blank">GitHub →</a>
296
+ <a class="flink" href="https://lablab.ai" target="_blank">lablab.ai →</a>
297
+ <a class="flink" href="https://cloud.amd.com" target="_blank">AMD Cloud →</a>
298
+ </div>
299
+ </div>
300
+ """)
301
+
302
+ btn.click(fn=answer, inputs=[question, opa, opb, opc, opd], outputs=output)
303
+
304
+ if __name__ == "__main__":
305
+ demo.launch(server_name="0.0.0.0", server_port=7860, share=True)