HK2184's picture
Update app.py
b0f0e1c verified
import os
import torch
import gradio as gr
from transformers import AutoTokenizer, AutoModelForCausalLM
from peft import PeftModel
# ← CHANGE 1: ROCm env vars removed
BASE_MODEL = "Qwen/Qwen3-1.7B"
ADAPTER_PATH = "HK2184/medqa-qwen3-lora" # ← CHANGE 2: HF Hub instead of ./outputs
print("Loading tokenizer...")
tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL, trust_remote_code=True)
tokenizer.pad_token = tokenizer.eos_token
tokenizer.padding_side = "left"
print("Loading model...")
DTYPE = torch.bfloat16 if torch.cuda.is_available() else torch.float32 # ← CHANGE 3: auto dtype
base = AutoModelForCausalLM.from_pretrained(
BASE_MODEL,
dtype=DTYPE,
device_map="auto",
trust_remote_code=True,
low_cpu_mem_usage=True,
)
model = PeftModel.from_pretrained(base, ADAPTER_PATH)
model = model.merge_and_unload()
model.eval()
print("Ready!")
EXAMPLES = [
["Which artery is occluded in inferior MI with ST elevation in II, III, aVF?",
"Left anterior descending artery", "Right coronary artery",
"Left circumflex artery", "Left main coronary artery"],
["First-line treatment for hypertensive emergency?",
"Oral amlodipine", "IV labetalol or IV nitroprusside",
"Sublingual nifedipine", "IM hydralazine"],
["Most common cause of community-acquired pneumonia?",
"Klebsiella pneumoniae", "Streptococcus pneumoniae",
"Haemophilus influenzae", "Mycoplasma pneumoniae"],
["Drug of choice for absence seizures?",
"Phenytoin", "Carbamazepine",
"Ethosuximide", "Valproate"],
]
def answer(question, opa, opb, opc, opd):
if not question.strip():
return "Please enter a question."
if not all([opa.strip(), opb.strip(), opc.strip(), opd.strip()]):
return "Please fill in all four options."
prompt = (
f"### Question:\n{question}\n\n"
f"### Options:\n"
f"A) {opa}\nB) {opb}\nC) {opc}\nD) {opd}\n\n"
f"### Answer:\n"
)
inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
with torch.no_grad():
out = model.generate(
**inputs,
max_new_tokens=200,
do_sample=True,
temperature=0.7,
top_p=0.9,
top_k=50,
repetition_penalty=1.3,
eos_token_id=tokenizer.eos_token_id,
pad_token_id=tokenizer.eos_token_id,
)
new = out[0][inputs["input_ids"].shape[-1]:]
return tokenizer.decode(new, skip_special_tokens=True)
CSS = """
@import url('https://fonts.googleapis.com/css2?family=Syne:wght@400;600;700;800&family=DM+Sans:wght@300;400;500&display=swap');
:root {
--bg: #080d1a;
--surface: #0f1624;
--surface2: #162030;
--border: #1a3356;
--accent: #00c8f0;
--accent2: #0055ff;
--green: #00f0a0;
--text: #deeeff;
--muted: #4a6080;
--danger: #ff3366;
}
body, .gradio-container {
background: var(--bg) !important;
font-family: 'DM Sans', sans-serif !important;
color: var(--text) !important;
}
.gradio-container {
max-width: 1080px !important;
margin: 0 auto !important;
padding: 0 20px 60px !important;
}
#header {
padding: 44px 0 28px;
border-bottom: 1px solid var(--border);
margin-bottom: 32px;
position: relative;
}
#header::after {
content: '';
position: absolute;
bottom: -1px; left: 0; right: 0; height: 1px;
background: linear-gradient(90deg, var(--accent2), var(--accent), var(--green));
}
.badges { display: flex; gap: 8px; margin-bottom: 14px; flex-wrap: wrap; }
.badge {
font-size: 10px; font-weight: 600;
letter-spacing: 0.1em; text-transform: uppercase;
padding: 3px 9px; border-radius: 4px; border: 1px solid;
}
.b-amd { color: #ff6030; border-color: #ff603030; background: #ff603010; }
.b-rocm { color: var(--accent); border-color: #00c8f030; background: #00c8f008; }
.b-lora { color: var(--green); border-color: #00f0a030; background: #00f0a008; }
.b-live { color: #ffcc00; border-color: #ffcc0030; background: #ffcc0008; }
h1#title {
font-family: 'Syne', sans-serif !important;
font-size: 42px !important; font-weight: 800 !important;
letter-spacing: -0.03em !important; line-height: 1 !important;
color: var(--text) !important; margin-bottom: 10px !important;
}
h1#title em { color: var(--accent); font-style: normal; }
.subtitle { font-size: 14px; color: var(--muted); font-weight: 300; line-height: 1.6; max-width: 520px; }
#stats {
display: flex; border: 1px solid var(--border);
border-radius: 12px; overflow: hidden;
background: var(--surface); margin-bottom: 28px;
}
.stat { flex: 1; padding: 14px 16px; text-align: center; border-right: 1px solid var(--border); }
.stat:last-child { border-right: none; }
.sv { font-family: 'Syne', sans-serif; font-size: 20px; font-weight: 700; color: var(--accent); display: block; }
.sl { font-size: 10px; color: var(--muted); text-transform: uppercase; letter-spacing: 0.08em; }
.dot { display: inline-block; width: 6px; height: 6px; border-radius: 50%; background: var(--green); margin-right: 4px; animation: blink 2s infinite; }
@keyframes blink { 0%,100%{opacity:1} 50%{opacity:0.3} }
label span, .label-wrap span {
font-family: 'DM Sans', sans-serif !important;
font-size: 11px !important; font-weight: 500 !important;
color: var(--muted) !important; text-transform: uppercase !important;
letter-spacing: 0.07em !important;
}
textarea, input[type=text] {
background: var(--surface2) !important;
border: 1px solid var(--border) !important;
border-radius: 10px !important;
color: var(--text) !important;
font-family: 'DM Sans', sans-serif !important;
font-size: 14px !important; line-height: 1.6 !important;
transition: border-color 0.2s, box-shadow 0.2s !important;
}
textarea:focus, input[type=text]:focus {
border-color: var(--accent) !important;
box-shadow: 0 0 0 3px #00c8f012 !important;
outline: none !important;
}
.section-label {
font-size: 10px; font-weight: 600;
letter-spacing: 0.12em; text-transform: uppercase;
color: var(--muted); margin-bottom: 10px;
display: flex; align-items: center; gap: 7px;
}
.section-label::before {
content: ''; width: 5px; height: 5px; border-radius: 50%;
background: var(--accent); display: inline-block;
}
button.lg.primary {
background: linear-gradient(135deg, var(--accent2), var(--accent)) !important;
border: none !important; border-radius: 10px !important;
color: #fff !important; font-family: 'Syne', sans-serif !important;
font-size: 14px !important; font-weight: 700 !important;
letter-spacing: 0.04em !important; padding: 14px !important;
width: 100% !important; margin-top: 14px !important;
cursor: pointer !important;
transition: opacity 0.2s, transform 0.15s !important;
}
button.lg.primary:hover { opacity: 0.85 !important; transform: translateY(-1px) !important; }
.out-box textarea {
background: var(--surface2) !important;
border: 1px solid var(--border) !important;
border-radius: 10px !important;
font-size: 14px !important; line-height: 1.8 !important;
color: var(--text) !important; min-height: 280px !important;
}
.examples-holder table {
background: var(--surface) !important;
border: 1px solid var(--border) !important;
border-radius: 10px !important; overflow: hidden !important;
}
.examples-holder td, .examples-holder th {
background: transparent !important; color: var(--text) !important;
font-size: 13px !important; border-color: var(--border) !important;
font-family: 'DM Sans', sans-serif !important;
}
.examples-holder tr:hover td { background: var(--surface2) !important; cursor: pointer; }
#footer {
margin-top: 44px; padding-top: 22px;
border-top: 1px solid var(--border);
display: flex; justify-content: space-between;
align-items: center; flex-wrap: wrap; gap: 10px;
}
.fl { font-size: 12px; color: var(--muted); }
.fl strong { color: var(--text); }
.fr { display: flex; gap: 14px; }
.flink { font-size: 12px; color: var(--accent); text-decoration: none; }
"""
with gr.Blocks(css=CSS, title="MedQA β€” AMD ROCm") as demo:
gr.HTML("""
<div id="header">
<div class="badges">
<span class="badge b-amd">AMD MI300X</span>
<span class="badge b-rocm">ROCm 7.2</span>
<span class="badge b-lora">LoRA Fine-tuned</span>
<span class="badge b-live"><span class="dot"></span>Live Inference</span>
</div>
<h1 id="title">Med<em>QA</em> Assistant</h1>
<p class="subtitle">
Clinical question-answering AI fine-tuned on MedMCQA.
Running on AMD Instinct MI300X via ROCm β€” no CUDA required.
</p>
</div>
<div id="stats">
<div class="stat"><span class="sv">1.7B</span><span class="sl">Parameters</span></div>
<div class="stat"><span class="sv">LoRA</span><span class="sl">Fine-tuning</span></div>
<div class="stat"><span class="sv">193k</span><span class="sl">Training QA</span></div>
<div class="stat"><span class="sv">MI300X</span><span class="sl">AMD GPU</span></div>
<div class="stat"><span class="sv">bf16</span><span class="sl">Precision</span></div>
</div>
""")
with gr.Row():
with gr.Column(scale=1):
gr.HTML('<div class="section-label">Clinical Question</div>')
question = gr.Textbox(
label="",
placeholder="e.g. A 45-year-old presents with sudden onset severe headache...",
lines=4,
)
gr.HTML('<div class="section-label" style="margin-top:14px">Answer Options</div>')
with gr.Row():
opa = gr.Textbox(label="Option A", placeholder="First option")
opb = gr.Textbox(label="Option B", placeholder="Second option")
with gr.Row():
opc = gr.Textbox(label="Option C", placeholder="Third option")
opd = gr.Textbox(label="Option D", placeholder="Fourth option")
btn = gr.Button("Analyze Question", variant="primary")
with gr.Column(scale=1):
gr.HTML('<div class="section-label">AI Answer & Reasoning</div>')
output = gr.Textbox(
label="",
placeholder="Answer and clinical explanation will appear here...",
lines=14,
elem_classes=["out-box"],
)
gr.HTML('<div class="section-label" style="margin-top:24px">Sample Questions β€” click any to load</div>')
gr.Examples(
examples=EXAMPLES,
inputs=[question, opa, opb, opc, opd],
label="",
)
gr.HTML("""
<div id="footer">
<div class="fl">
Built on <strong>AMD Developer Cloud</strong> &nbsp;Β·&nbsp;
Model: <strong>Qwen3-1.7B + LoRA</strong> &nbsp;Β·&nbsp;
Dataset: <strong>MedMCQA</strong>
</div>
<div class="fr">
<a class="flink" href="https://github.com/HK2184/MedQA-Medical-AI-on-AMD-ROCm" target="_blank">GitHub β†’</a>
<a class="flink" href="https://lablab.ai" target="_blank">lablab.ai β†’</a>
<a class="flink" href="https://cloud.amd.com" target="_blank">AMD Cloud β†’</a>
</div>
</div>
""")
btn.click(fn=answer, inputs=[question, opa, opb, opc, opd], outputs=output)
if __name__ == "__main__":
demo.launch() # ← CHANGE 4: no server_name/port/share for HF Spaces