Update app.py
Browse files
app.py
CHANGED
|
@@ -1,15 +1,32 @@
|
|
| 1 |
import os
|
| 2 |
-
|
| 3 |
import gradio as gr
|
| 4 |
-
|
| 5 |
-
|
| 6 |
-
|
| 7 |
-
|
| 8 |
-
|
| 9 |
-
BASE_MODEL = "Qwen/
|
| 10 |
-
ADAPTER_PATH = "./outputs
|
| 11 |
-
|
| 12 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 13 |
|
| 14 |
EXAMPLES = [
|
| 15 |
["Which artery is occluded in inferior MI with ST elevation in II, III, aVF?",
|
|
@@ -29,20 +46,29 @@ EXAMPLES = [
|
|
| 29 |
def answer(question, opa, opb, opc, opd):
|
| 30 |
if not question.strip():
|
| 31 |
return "Please enter a question."
|
| 32 |
-
|
| 33 |
-
|
| 34 |
-
|
| 35 |
-
|
| 36 |
-
|
| 37 |
-
|
| 38 |
-
|
| 39 |
-
|
| 40 |
-
|
| 41 |
-
|
| 42 |
-
|
| 43 |
-
|
| 44 |
-
|
| 45 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 46 |
|
| 47 |
CSS = """
|
| 48 |
@import url('https://fonts.googleapis.com/css2?family=Syne:wght@400;600;700;800&family=DM+Sans:wght@300;400;500&display=swap');
|
|
@@ -72,7 +98,6 @@ body, .gradio-container {
|
|
| 72 |
padding: 0 20px 60px !important;
|
| 73 |
}
|
| 74 |
|
| 75 |
-
/* Header */
|
| 76 |
#header {
|
| 77 |
padding: 44px 0 28px;
|
| 78 |
border-bottom: 1px solid var(--border);
|
|
@@ -105,7 +130,6 @@ h1#title {
|
|
| 105 |
h1#title em { color: var(--accent); font-style: normal; }
|
| 106 |
.subtitle { font-size: 14px; color: var(--muted); font-weight: 300; line-height: 1.6; max-width: 520px; }
|
| 107 |
|
| 108 |
-
/* Stats */
|
| 109 |
#stats {
|
| 110 |
display: flex; border: 1px solid var(--border);
|
| 111 |
border-radius: 12px; overflow: hidden;
|
|
@@ -118,7 +142,6 @@ h1#title em { color: var(--accent); font-style: normal; }
|
|
| 118 |
.dot { display: inline-block; width: 6px; height: 6px; border-radius: 50%; background: var(--green); margin-right: 4px; animation: blink 2s infinite; }
|
| 119 |
@keyframes blink { 0%,100%{opacity:1} 50%{opacity:0.3} }
|
| 120 |
|
| 121 |
-
/* Inputs */
|
| 122 |
label span, .label-wrap span {
|
| 123 |
font-family: 'DM Sans', sans-serif !important;
|
| 124 |
font-size: 11px !important; font-weight: 500 !important;
|
|
@@ -140,7 +163,6 @@ textarea:focus, input[type=text]:focus {
|
|
| 140 |
outline: none !important;
|
| 141 |
}
|
| 142 |
|
| 143 |
-
/* Section labels */
|
| 144 |
.section-label {
|
| 145 |
font-size: 10px; font-weight: 600;
|
| 146 |
letter-spacing: 0.12em; text-transform: uppercase;
|
|
@@ -152,7 +174,6 @@ textarea:focus, input[type=text]:focus {
|
|
| 152 |
background: var(--accent); display: inline-block;
|
| 153 |
}
|
| 154 |
|
| 155 |
-
/* Button */
|
| 156 |
button.lg.primary {
|
| 157 |
background: linear-gradient(135deg, var(--accent2), var(--accent)) !important;
|
| 158 |
border: none !important; border-radius: 10px !important;
|
|
@@ -165,7 +186,6 @@ button.lg.primary {
|
|
| 165 |
}
|
| 166 |
button.lg.primary:hover { opacity: 0.85 !important; transform: translateY(-1px) !important; }
|
| 167 |
|
| 168 |
-
/* Output */
|
| 169 |
.out-box textarea {
|
| 170 |
background: var(--surface2) !important;
|
| 171 |
border: 1px solid var(--border) !important;
|
|
@@ -174,7 +194,6 @@ button.lg.primary:hover { opacity: 0.85 !important; transform: translateY(-1px)
|
|
| 174 |
color: var(--text) !important; min-height: 280px !important;
|
| 175 |
}
|
| 176 |
|
| 177 |
-
/* Examples */
|
| 178 |
.examples-holder table {
|
| 179 |
background: var(--surface) !important;
|
| 180 |
border: 1px solid var(--border) !important;
|
|
@@ -187,7 +206,6 @@ button.lg.primary:hover { opacity: 0.85 !important; transform: translateY(-1px)
|
|
| 187 |
}
|
| 188 |
.examples-holder tr:hover td { background: var(--surface2) !important; cursor: pointer; }
|
| 189 |
|
| 190 |
-
/* Footer */
|
| 191 |
#footer {
|
| 192 |
margin-top: 44px; padding-top: 22px;
|
| 193 |
border-top: 1px solid var(--border);
|
|
@@ -206,7 +224,7 @@ with gr.Blocks(css=CSS, title="MedQA β AMD ROCm") as demo:
|
|
| 206 |
<div id="header">
|
| 207 |
<div class="badges">
|
| 208 |
<span class="badge b-amd">AMD MI300X</span>
|
| 209 |
-
<span class="badge b-rocm">ROCm
|
| 210 |
<span class="badge b-lora">LoRA Fine-tuned</span>
|
| 211 |
<span class="badge b-live"><span class="dot"></span>Live Inference</span>
|
| 212 |
</div>
|
|
@@ -217,7 +235,7 @@ with gr.Blocks(css=CSS, title="MedQA β AMD ROCm") as demo:
|
|
| 217 |
</p>
|
| 218 |
</div>
|
| 219 |
<div id="stats">
|
| 220 |
-
<div class="stat"><span class="sv">1.
|
| 221 |
<div class="stat"><span class="sv">LoRA</span><span class="sl">Fine-tuning</span></div>
|
| 222 |
<div class="stat"><span class="sv">193k</span><span class="sl">Training QA</span></div>
|
| 223 |
<div class="stat"><span class="sv">MI300X</span><span class="sl">AMD GPU</span></div>
|
|
@@ -256,17 +274,17 @@ with gr.Blocks(css=CSS, title="MedQA β AMD ROCm") as demo:
|
|
| 256 |
examples=EXAMPLES,
|
| 257 |
inputs=[question, opa, opb, opc, opd],
|
| 258 |
label="",
|
| 259 |
-
|
| 260 |
|
| 261 |
gr.HTML("""
|
| 262 |
<div id="footer">
|
| 263 |
<div class="fl">
|
| 264 |
Built on <strong>AMD Developer Cloud</strong> Β·
|
| 265 |
-
Model: <strong>
|
| 266 |
Dataset: <strong>MedMCQA</strong>
|
| 267 |
</div>
|
| 268 |
<div class="fr">
|
| 269 |
-
<a class="flink" href="https://github.com" target="_blank">GitHub β</a>
|
| 270 |
<a class="flink" href="https://lablab.ai" target="_blank">lablab.ai β</a>
|
| 271 |
<a class="flink" href="https://cloud.amd.com" target="_blank">AMD Cloud β</a>
|
| 272 |
</div>
|
|
@@ -276,4 +294,4 @@ with gr.Blocks(css=CSS, title="MedQA β AMD ROCm") as demo:
|
|
| 276 |
btn.click(fn=answer, inputs=[question, opa, opb, opc, opd], outputs=output)
|
| 277 |
|
| 278 |
if __name__ == "__main__":
|
| 279 |
-
demo.launch(
|
|
|
|
# MedMCQA demo app: LoRA-fine-tuned Qwen3-1.7B served through Gradio.
# Base model + adapter are loaded once at import time so every request
# reuses the merged weights.
import os
import torch
import gradio as gr
from transformers import AutoTokenizer, AutoModelForCausalLM
from peft import PeftModel

# ← CHANGE 1: ROCm env vars removed

BASE_MODEL = "Qwen/Qwen3-1.7B"
ADAPTER_PATH = "HK2184/medqa-qwen3-lora"  # ← CHANGE 2: HF Hub instead of ./outputs

print("Loading tokenizer...")
tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL, trust_remote_code=True)
# No dedicated pad token on this tokenizer config here; reuse EOS for padding.
tokenizer.pad_token = tokenizer.eos_token
tokenizer.padding_side = "left"  # left-pad so generation continues from the prompt end

print("Loading model...")
# ← CHANGE 3: auto dtype.  NOTE(review): torch.cuda.is_available() also
# reports True on ROCm builds of torch, so an MI300X host gets bfloat16;
# CPU-only hosts fall back to float32 — confirm on the target box.
DTYPE = torch.bfloat16 if torch.cuda.is_available() else torch.float32
base = AutoModelForCausalLM.from_pretrained(
    BASE_MODEL,
    dtype=DTYPE,  # NOTE(review): `dtype=` is the newer kwarg name; older transformers releases expect `torch_dtype=` — verify the pinned version
    device_map="auto",
    trust_remote_code=True,
    low_cpu_mem_usage=True,
)
model = PeftModel.from_pretrained(base, ADAPTER_PATH)
# Fold the LoRA deltas into the base weights so inference pays no adapter overhead.
model = model.merge_and_unload()
model.eval()  # inference mode: disables dropout and similar train-time behavior
print("Ready!")
| 30 |
|
| 31 |
EXAMPLES = [
|
| 32 |
["Which artery is occluded in inferior MI with ST elevation in II, III, aVF?",
|
|
|
|
def answer(question, opa, opb, opc, opd):
    """Run the fine-tuned model on one multiple-choice medical question.

    All five parameters are plain strings from the Gradio textboxes: the
    question plus the four candidate options (A-D).  Returns the decoded
    model completion, or a short validation message when the question or
    any option is blank.
    """
    if not question.strip():
        return "Please enter a question."
    options = (opa, opb, opc, opd)
    if any(not opt.strip() for opt in options):
        return "Please fill in all four options."

    # Rebuild the exact "### Question / ### Options / ### Answer" prompt
    # layout the adapter was trained on.
    option_lines = "\n".join(
        f"{letter}) {text}" for letter, text in zip("ABCD", options)
    )
    prompt = f"### Question:\n{question}\n\n### Options:\n{option_lines}\n\n### Answer:\n"

    batch = tokenizer(prompt, return_tensors="pt").to(model.device)
    sampling = dict(
        max_new_tokens=200,
        do_sample=True,
        temperature=0.7,
        top_p=0.9,
        top_k=50,
        repetition_penalty=1.3,
        eos_token_id=tokenizer.eos_token_id,
        pad_token_id=tokenizer.eos_token_id,
    )
    with torch.no_grad():  # inference only; skip autograd bookkeeping
        generated = model.generate(**batch, **sampling)

    # Drop the echoed prompt tokens and decode only the new completion.
    prompt_len = batch["input_ids"].shape[-1]
    return tokenizer.decode(generated[0][prompt_len:], skip_special_tokens=True)
|
| 72 |
|
| 73 |
CSS = """
|
| 74 |
@import url('https://fonts.googleapis.com/css2?family=Syne:wght@400;600;700;800&family=DM+Sans:wght@300;400;500&display=swap');
|
|
|
|
| 98 |
padding: 0 20px 60px !important;
|
| 99 |
}
|
| 100 |
|
|
|
|
| 101 |
#header {
|
| 102 |
padding: 44px 0 28px;
|
| 103 |
border-bottom: 1px solid var(--border);
|
|
|
|
| 130 |
h1#title em { color: var(--accent); font-style: normal; }
|
| 131 |
.subtitle { font-size: 14px; color: var(--muted); font-weight: 300; line-height: 1.6; max-width: 520px; }
|
| 132 |
|
|
|
|
| 133 |
#stats {
|
| 134 |
display: flex; border: 1px solid var(--border);
|
| 135 |
border-radius: 12px; overflow: hidden;
|
|
|
|
| 142 |
.dot { display: inline-block; width: 6px; height: 6px; border-radius: 50%; background: var(--green); margin-right: 4px; animation: blink 2s infinite; }
|
| 143 |
@keyframes blink { 0%,100%{opacity:1} 50%{opacity:0.3} }
|
| 144 |
|
|
|
|
| 145 |
label span, .label-wrap span {
|
| 146 |
font-family: 'DM Sans', sans-serif !important;
|
| 147 |
font-size: 11px !important; font-weight: 500 !important;
|
|
|
|
| 163 |
outline: none !important;
|
| 164 |
}
|
| 165 |
|
|
|
|
| 166 |
.section-label {
|
| 167 |
font-size: 10px; font-weight: 600;
|
| 168 |
letter-spacing: 0.12em; text-transform: uppercase;
|
|
|
|
| 174 |
background: var(--accent); display: inline-block;
|
| 175 |
}
|
| 176 |
|
|
|
|
| 177 |
button.lg.primary {
|
| 178 |
background: linear-gradient(135deg, var(--accent2), var(--accent)) !important;
|
| 179 |
border: none !important; border-radius: 10px !important;
|
|
|
|
| 186 |
}
|
| 187 |
button.lg.primary:hover { opacity: 0.85 !important; transform: translateY(-1px) !important; }
|
| 188 |
|
|
|
|
| 189 |
.out-box textarea {
|
| 190 |
background: var(--surface2) !important;
|
| 191 |
border: 1px solid var(--border) !important;
|
|
|
|
| 194 |
color: var(--text) !important; min-height: 280px !important;
|
| 195 |
}
|
| 196 |
|
|
|
|
| 197 |
.examples-holder table {
|
| 198 |
background: var(--surface) !important;
|
| 199 |
border: 1px solid var(--border) !important;
|
|
|
|
| 206 |
}
|
| 207 |
.examples-holder tr:hover td { background: var(--surface2) !important; cursor: pointer; }
|
| 208 |
|
|
|
|
| 209 |
#footer {
|
| 210 |
margin-top: 44px; padding-top: 22px;
|
| 211 |
border-top: 1px solid var(--border);
|
|
|
|
| 224 |
<div id="header">
|
| 225 |
<div class="badges">
|
| 226 |
<span class="badge b-amd">AMD MI300X</span>
|
| 227 |
+
<span class="badge b-rocm">ROCm 7.2</span>
|
| 228 |
<span class="badge b-lora">LoRA Fine-tuned</span>
|
| 229 |
<span class="badge b-live"><span class="dot"></span>Live Inference</span>
|
| 230 |
</div>
|
|
|
|
| 235 |
</p>
|
| 236 |
</div>
|
| 237 |
<div id="stats">
|
| 238 |
+
<div class="stat"><span class="sv">1.7B</span><span class="sl">Parameters</span></div>
|
| 239 |
<div class="stat"><span class="sv">LoRA</span><span class="sl">Fine-tuning</span></div>
|
| 240 |
<div class="stat"><span class="sv">193k</span><span class="sl">Training QA</span></div>
|
| 241 |
<div class="stat"><span class="sv">MI300X</span><span class="sl">AMD GPU</span></div>
|
|
|
|
| 274 |
examples=EXAMPLES,
|
| 275 |
inputs=[question, opa, opb, opc, opd],
|
| 276 |
label="",
|
| 277 |
+
)
|
| 278 |
|
| 279 |
gr.HTML("""
|
| 280 |
<div id="footer">
|
| 281 |
<div class="fl">
|
| 282 |
Built on <strong>AMD Developer Cloud</strong> Β·
|
| 283 |
+
Model: <strong>Qwen3-1.7B + LoRA</strong> Β·
|
| 284 |
Dataset: <strong>MedMCQA</strong>
|
| 285 |
</div>
|
| 286 |
<div class="fr">
|
| 287 |
+
<a class="flink" href="https://github.com/HK2184/MedQA-Medical-AI-on-AMD-ROCm" target="_blank">GitHub β</a>
|
| 288 |
<a class="flink" href="https://lablab.ai" target="_blank">lablab.ai β</a>
|
| 289 |
<a class="flink" href="https://cloud.amd.com" target="_blank">AMD Cloud β</a>
|
| 290 |
</div>
|
|
|
|
| 294 |
btn.click(fn=answer, inputs=[question, opa, opb, opc, opd], outputs=output)
|
| 295 |
|
| 296 |
if __name__ == "__main__":
    # ← CHANGE 4: no server_name/port/share for HF Spaces — the Space
    # runtime supplies host/port, so a bare launch() is sufficient.
    demo.launch()