| |
|
|
| import os |
| import traceback |
| import pandas as pd |
| import torch |
| import gradio as gr |
| from transformers import ( |
| logging, |
| AutoProcessor, |
| AutoTokenizer, |
| AutoModelForImageTextToText |
| ) |
| from sklearn.model_selection import train_test_split |
|
|
| |
# Keep transformers quiet except for real errors.
logging.set_verbosity_error()

# Gated Gemma 3n multimodal checkpoint used for every generation below.
MODEL_ID = "google/gemma-3n-e2b-it"

# The Hugging Face token must arrive via the environment (Space secret);
# fail fast at startup rather than at the first model download.
HF_TOKEN = os.getenv("HF_TOKEN")
if not HF_TOKEN:
    raise RuntimeError("Missing HF_TOKEN in env vars – set it under Space Settings → Secrets")
|
|
| |
# Load the chat-template processor and tokenizer once at startup (network
# call to the HF Hub; requires the gated-model token).
# NOTE(review): `tokenizer` does not appear to be used anywhere in this file —
# generation goes through `processor` — confirm against other modules before
# removing it.
processor = AutoProcessor.from_pretrained(
    MODEL_ID, trust_remote_code=True, token=HF_TOKEN
)
tokenizer = AutoTokenizer.from_pretrained(
    MODEL_ID, trust_remote_code=True, token=HF_TOKEN
)
|
|
| |
def generate_and_export():
    """Generate 20 synthetic doctor notes, split them, and export predictions.

    Loads the Gemma model on demand (per click), generates 20 progress notes
    plus a second-pass "ground truth" per note, splits 15/5 into train/test,
    and writes outputs/inference.tsv (train: id, ground truth, prediction)
    and outputs/eval.csv (test: id, prediction).

    Returns:
        (status_message, inference_tsv_path, eval_csv_path) on success, or
        (error_message, None, None) on any failure (errors are caught so the
        Gradio UI always gets a displayable status).
    """
    model = None
    try:
        # Load lazily so the Space starts without holding model weights;
        # device_map="auto" places the fp16 weights on GPU when available.
        model = AutoModelForImageTextToText.from_pretrained(
            MODEL_ID,
            trust_remote_code=True,
            token=HF_TOKEN,
            torch_dtype=torch.float16,
            device_map="auto",
        )
        device = next(model.parameters()).device

        def to_soap(text: str) -> str:
            """Run one chat-templated generation and return only the new text."""
            inputs = processor.apply_chat_template(
                [
                    {"role": "system", "content": [{"type": "text", "text": "You are a medical AI assistant."}]},
                    {"role": "user", "content": [{"type": "text", "text": text}]},
                ],
                add_generation_prompt=True,
                tokenize=True,
                return_tensors="pt",
                return_dict=True,
            ).to(device)
            out = model.generate(
                **inputs,
                max_new_tokens=400,
                do_sample=True,
                top_p=0.95,
                temperature=0.1,
                pad_token_id=processor.tokenizer.eos_token_id,
                use_cache=False,
            )
            # Decode only the continuation, not the echoed prompt tokens.
            prompt_len = inputs["input_ids"].shape[-1]
            return processor.batch_decode(
                out[:, prompt_len:], skip_special_tokens=True
            )[0].strip()

        # First pass: synthesize a note; second pass: feed the note back in
        # to produce the "ground truth" SOAP counterpart.
        docs, gts = [], []
        for i in range(1, 21):
            doc = to_soap("Generate a realistic, concise doctor's progress note for a single patient encounter.")
            docs.append(doc)
            gts.append(to_soap(doc))
            # Periodically release cached GPU blocks; guarded so CPU-only
            # torch builds never touch the CUDA runtime.
            if i % 5 == 0 and torch.cuda.is_available():
                torch.cuda.empty_cache()

        df = pd.DataFrame({"doc_note": docs, "ground_truth_soap": gts})
        # Fixed seed keeps the 15/5 split reproducible across clicks.
        train_df, test_df = train_test_split(df, test_size=5, random_state=42)

        os.makedirs("outputs", exist_ok=True)

        # Train split: model predictions next to ground truth for inspection.
        train_preds = [to_soap(d) for d in train_df["doc_note"]]
        inf = train_df.reset_index(drop=True).copy()
        inf["id"] = inf.index + 1
        inf["predicted_soap"] = train_preds
        inf[["id", "ground_truth_soap", "predicted_soap"]].to_csv(
            "outputs/inference.tsv", sep="\t", index=False
        )

        # Test split: predictions only (ids restart at 1).
        test_preds = [to_soap(d) for d in test_df["doc_note"]]
        pd.DataFrame({
            "id": range(1, len(test_preds) + 1),
            "predicted_soap": test_preds,
        }).to_csv("outputs/eval.csv", index=False)

        return (
            "✅ Done with 20 notes (15 train / 5 test)!",
            "outputs/inference.tsv",
            "outputs/eval.csv",
        )

    except Exception as e:
        # Surface any failure in the UI status box instead of crashing Gradio.
        traceback.print_exc()
        return (f"❌ Error: {e}", None, None)

    finally:
        # Drop the weights so repeated clicks don't accumulate memory; the
        # original leaked the loaded model until the next garbage collection.
        if model is not None:
            del model
            if torch.cuda.is_available():
                torch.cuda.empty_cache()
|
|
| |
# Minimal Gradio UI: a single button kicks off the full generate-and-export
# run; components are declared in display order.
with gr.Blocks() as demo:
    gr.Markdown("# Gemma‑3n SOAP Generator 🩺")
    btn = gr.Button("Generate & Export 20 Notes")
    # Read-only status line populated with the success/error message.
    status = gr.Textbox(interactive=False, label="Status")
    inf_file = gr.File(label="Download inference.tsv")
    eval_file= gr.File(label="Download eval.csv")

    # Wire the button to the worker; the three outputs map positionally to
    # the 3-tuple returned by generate_and_export.
    btn.click(
        fn=generate_and_export,
        inputs=None,
        outputs=[status, inf_file, eval_file]
    )

if __name__ == "__main__":
    demo.launch()
|
|