import os import torch import gradio as gr from transformers import AutoTokenizer, AutoModelForCausalLM from peft import PeftModel # ← CHANGE 1: ROCm env vars removed BASE_MODEL = "Qwen/Qwen3-1.7B" ADAPTER_PATH = "HK2184/medqa-qwen3-lora" # ← CHANGE 2: HF Hub instead of ./outputs print("Loading tokenizer...") tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL, trust_remote_code=True) tokenizer.pad_token = tokenizer.eos_token tokenizer.padding_side = "left" print("Loading model...") DTYPE = torch.bfloat16 if torch.cuda.is_available() else torch.float32 # ← CHANGE 3: auto dtype base = AutoModelForCausalLM.from_pretrained( BASE_MODEL, dtype=DTYPE, device_map="auto", trust_remote_code=True, low_cpu_mem_usage=True, ) model = PeftModel.from_pretrained(base, ADAPTER_PATH) model = model.merge_and_unload() model.eval() print("Ready!") EXAMPLES = [ ["Which artery is occluded in inferior MI with ST elevation in II, III, aVF?", "Left anterior descending artery", "Right coronary artery", "Left circumflex artery", "Left main coronary artery"], ["First-line treatment for hypertensive emergency?", "Oral amlodipine", "IV labetalol or IV nitroprusside", "Sublingual nifedipine", "IM hydralazine"], ["Most common cause of community-acquired pneumonia?", "Klebsiella pneumoniae", "Streptococcus pneumoniae", "Haemophilus influenzae", "Mycoplasma pneumoniae"], ["Drug of choice for absence seizures?", "Phenytoin", "Carbamazepine", "Ethosuximide", "Valproate"], ] def answer(question, opa, opb, opc, opd): if not question.strip(): return "Please enter a question." if not all([opa.strip(), opb.strip(), opc.strip(), opd.strip()]): return "Please fill in all four options." prompt = ( f"### Question:\n{question}\n\n" f"### Options:\n" f"A) {opa}\nB) {opb}\nC) {opc}\nD) {opd}\n\n" f"### Answer:\n" ) inputs = tokenizer(prompt, return_tensors="pt").to(model.device) with torch.no_grad(): out = model.generate( **inputs, max_new_tokens=200, do_sample=True, temperature=0.7, top_p=0.9, top_k=50, repetition_penalty=1.3, eos_token_id=tokenizer.eos_token_id, pad_token_id=tokenizer.eos_token_id, ) new = out[0][inputs["input_ids"].shape[-1]:] return tokenizer.decode(new, skip_special_tokens=True) CSS = """ @import url('https://fonts.googleapis.com/css2?family=Syne:wght@400;600;700;800&family=DM+Sans:wght@300;400;500&display=swap'); :root { --bg: #080d1a; --surface: #0f1624; --surface2: #162030; --border: #1a3356; --accent: #00c8f0; --accent2: #0055ff; --green: #00f0a0; --text: #deeeff; --muted: #4a6080; --danger: #ff3366; } body, .gradio-container { background: var(--bg) !important; font-family: 'DM Sans', sans-serif !important; color: var(--text) !important; } .gradio-container { max-width: 1080px !important; margin: 0 auto !important; padding: 0 20px 60px !important; } #header { padding: 44px 0 28px; border-bottom: 1px solid var(--border); margin-bottom: 32px; position: relative; } #header::after { content: ''; position: absolute; bottom: -1px; left: 0; right: 0; height: 1px; background: linear-gradient(90deg, var(--accent2), var(--accent), var(--green)); } .badges { display: flex; gap: 8px; margin-bottom: 14px; flex-wrap: wrap; } .badge { font-size: 10px; font-weight: 600; letter-spacing: 0.1em; text-transform: uppercase; padding: 3px 9px; border-radius: 4px; border: 1px solid; } .b-amd { color: #ff6030; border-color: #ff603030; background: #ff603010; } .b-rocm { color: var(--accent); border-color: #00c8f030; background: #00c8f008; } .b-lora { color: var(--green); border-color: #00f0a030; background: #00f0a008; } .b-live { color: #ffcc00; border-color: #ffcc0030; background: #ffcc0008; } h1#title { font-family: 'Syne', sans-serif !important; font-size: 42px !important; font-weight: 800 !important; letter-spacing: -0.03em !important; line-height: 1 !important; color: var(--text) !important; margin-bottom: 10px !important; } h1#title em { color: var(--accent); font-style: normal; } .subtitle { font-size: 14px; color: var(--muted); font-weight: 300; line-height: 1.6; max-width: 520px; } #stats { display: flex; border: 1px solid var(--border); border-radius: 12px; overflow: hidden; background: var(--surface); margin-bottom: 28px; } .stat { flex: 1; padding: 14px 16px; text-align: center; border-right: 1px solid var(--border); } .stat:last-child { border-right: none; } .sv { font-family: 'Syne', sans-serif; font-size: 20px; font-weight: 700; color: var(--accent); display: block; } .sl { font-size: 10px; color: var(--muted); text-transform: uppercase; letter-spacing: 0.08em; } .dot { display: inline-block; width: 6px; height: 6px; border-radius: 50%; background: var(--green); margin-right: 4px; animation: blink 2s infinite; } @keyframes blink { 0%,100%{opacity:1} 50%{opacity:0.3} } label span, .label-wrap span { font-family: 'DM Sans', sans-serif !important; font-size: 11px !important; font-weight: 500 !important; color: var(--muted) !important; text-transform: uppercase !important; letter-spacing: 0.07em !important; } textarea, input[type=text] { background: var(--surface2) !important; border: 1px solid var(--border) !important; border-radius: 10px !important; color: var(--text) !important; font-family: 'DM Sans', sans-serif !important; font-size: 14px !important; line-height: 1.6 !important; transition: border-color 0.2s, box-shadow 0.2s !important; } textarea:focus, input[type=text]:focus { border-color: var(--accent) !important; box-shadow: 0 0 0 3px #00c8f012 !important; outline: none !important; } .section-label { font-size: 10px; font-weight: 600; letter-spacing: 0.12em; text-transform: uppercase; color: var(--muted); margin-bottom: 10px; display: flex; align-items: center; gap: 7px; } .section-label::before { content: ''; width: 5px; height: 5px; border-radius: 50%; background: var(--accent); display: inline-block; } button.lg.primary { background: linear-gradient(135deg, var(--accent2), var(--accent)) !important; border: none !important; border-radius: 10px !important; color: #fff !important; font-family: 'Syne', sans-serif !important; font-size: 14px !important; font-weight: 700 !important; letter-spacing: 0.04em !important; padding: 14px !important; width: 100% !important; margin-top: 14px !important; cursor: pointer !important; transition: opacity 0.2s, transform 0.15s !important; } button.lg.primary:hover { opacity: 0.85 !important; transform: translateY(-1px) !important; } .out-box textarea { background: var(--surface2) !important; border: 1px solid var(--border) !important; border-radius: 10px !important; font-size: 14px !important; line-height: 1.8 !important; color: var(--text) !important; min-height: 280px !important; } .examples-holder table { background: var(--surface) !important; border: 1px solid var(--border) !important; border-radius: 10px !important; overflow: hidden !important; } .examples-holder td, .examples-holder th { background: transparent !important; color: var(--text) !important; font-size: 13px !important; border-color: var(--border) !important; font-family: 'DM Sans', sans-serif !important; } .examples-holder tr:hover td { background: var(--surface2) !important; cursor: pointer; } #footer { margin-top: 44px; padding-top: 22px; border-top: 1px solid var(--border); display: flex; justify-content: space-between; align-items: center; flex-wrap: wrap; gap: 10px; } .fl { font-size: 12px; color: var(--muted); } .fl strong { color: var(--text); } .fr { display: flex; gap: 14px; } .flink { font-size: 12px; color: var(--accent); text-decoration: none; } """ with gr.Blocks(css=CSS, title="MedQA — AMD ROCm") as demo: gr.HTML("""
1.7BParameters
LoRAFine-tuning
193kTraining QA
MI300XAMD GPU
bf16Precision
""") with gr.Row(): with gr.Column(scale=1): gr.HTML('
Clinical Question
') question = gr.Textbox( label="", placeholder="e.g. A 45-year-old presents with sudden onset severe headache...", lines=4, ) gr.HTML('
Answer Options
') with gr.Row(): opa = gr.Textbox(label="Option A", placeholder="First option") opb = gr.Textbox(label="Option B", placeholder="Second option") with gr.Row(): opc = gr.Textbox(label="Option C", placeholder="Third option") opd = gr.Textbox(label="Option D", placeholder="Fourth option") btn = gr.Button("Analyze Question", variant="primary") with gr.Column(scale=1): gr.HTML('
AI Answer & Reasoning
') output = gr.Textbox( label="", placeholder="Answer and clinical explanation will appear here...", lines=14, elem_classes=["out-box"], ) gr.HTML('
Sample Questions — click any to load
') gr.Examples( examples=EXAMPLES, inputs=[question, opa, opb, opc, opd], label="", ) gr.HTML(""" """) btn.click(fn=answer, inputs=[question, opa, opb, opc, opd], outputs=output) if __name__ == "__main__": demo.launch() # ← CHANGE 4: no server_name/port/share for HF Spaces