Bio_ClinicalBERT_MIMIC_IV_death_in_30_prediction_IA3_ti

This model predicts 30-day mortality at the time of hospital discharge. It was trained on discharge notes from the MIMIC-IV dataset, a publicly available collection of de-identified Electronic Health Records (EHRs). Training used a novel tabular-infused variant of IA3, in which pre-operative tabular features (e.g., patient demographics and insurance information) were used to initialize the newly introduced IA3 parameters.
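For readers unfamiliar with IA3: it freezes the base model and learns small vectors that rescale intermediate activations elementwise. Standard IA3 rescales attention keys, attention values, and feed-forward activations, and initializes these vectors to ones so that training starts from the unmodified model. The NumPy sketch below covers only the key/value rescaling in a single attention head and omits the feed-forward part. The `tabular_init` helper is hypothetical. It shows one plausible way to seed the scaling vectors from a patient's tabular feature vector, using a fixed projection `proj` offset so that a zero feature vector recovers the standard all-ones initialization. It is an illustrative assumption, not the author's actual method.

```python
import numpy as np

def ia3_attention(x, W_q, W_k, W_v, l_k, l_v):
    """Single-head self-attention with IA3 rescaling of keys and values.

    x: (seq_len, d_model); W_*: (d_model, d_head); l_k, l_v: (d_head,)
    """
    q = x @ W_q
    k = (x @ W_k) * l_k  # IA3: elementwise rescale of the keys
    v = (x @ W_v) * l_v  # IA3: elementwise rescale of the values
    scores = q @ k.T / np.sqrt(q.shape[-1])
    scores -= scores.max(axis=-1, keepdims=True)  # numerical stability
    weights = np.exp(scores)
    weights /= weights.sum(axis=-1, keepdims=True)  # row-wise softmax
    return weights @ v

def tabular_init(tab_features, proj):
    """HYPOTHETICAL: map tabular features (n_feat,) to an IA3 vector (d_head,).

    The 1 + tanh(...) form is an assumption chosen so that all-zero
    features give the standard all-ones IA3 initialization.
    """
    return 1.0 + np.tanh(tab_features @ proj)

# Usage: seed l_k and l_v from a (toy, random) tabular feature vector
rng = np.random.default_rng(0)
x = rng.normal(size=(4, 8))
W_q, W_k, W_v = (rng.normal(size=(8, 8)) for _ in range(3))
proj = rng.normal(size=(5, 8)) * 0.1
tab = rng.normal(size=5)
out = ia3_attention(x, W_q, W_k, W_v, tabular_init(tab, proj), tabular_init(tab, proj))
```

With all-ones vectors, `ia3_attention` reduces to plain attention. This is why IA3 can be added to a pretrained model without changing its initial behaviour.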

Model Details

How to use the model

from transformers import AutoTokenizer, AutoModelForSequenceClassification

# Load the tokenizer and the binary classification model from the Hugging Face Hub
tokenizer = AutoTokenizer.from_pretrained("cja5553/Bio_ClinicalBERT_MIMIC_IV_death_in_30_prediction_IA3_ti")
model = AutoModelForSequenceClassification.from_pretrained("cja5553/Bio_ClinicalBERT_MIMIC_IV_death_in_30_prediction_IA3_ti")

You can then use the function below to get the predicted probabilities for a single discharge note:

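import torch

def get_outcome(tokenizer, model, text, device="cuda:0", max_length=512):
    # Fall back to the CPU if the requested CUDA device is not available
    if str(device).startswith("cuda") and not torch.cuda.is_available():
        device = "cpu"
    device = torch.device(device)
    model = model.to(device)
    model.eval()  # disable dropout for inference

    # Tokenize one note, truncating/padding it to max_length tokens
    inputs = tokenizer(
        text,
        return_tensors="pt",
        max_length=max_length,
        truncation=True,
        padding="max_length",
    ).to(device)

    with torch.no_grad():
        outputs = model(**inputs)
        # Softmax over the two logits of the single example -> shape (2,)
        probs = torch.softmax(outputs.logits, dim=-1)[0]

    probs = probs.cpu().numpy()
    # Index 0 = no death within 30 days ("False"), index 1 = death within 30 days ("True")
    return {
        "False": float(probs[0]),
        "True": float(probs[1]),
    }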
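`get_outcome` returns probabilities rather than a label. The small helper below converts its output dictionary into a yes/no prediction. The function name `to_label` and the 0.5 default cutoff are assumptions for illustration. Because 30-day mortality is usually a minority outcome, you may prefer a threshold tuned on validation data.

```python
def to_label(result, threshold=0.5):
    """HYPOTHETICAL helper: map {"False": p0, "True": p1} to a boolean.

    Returns True (predicted death within 30 days) when P("True") >= threshold.
    The 0.5 default is an assumption, not a tuned operating point.
    """
    if not 0.0 <= threshold <= 1.0:
        raise ValueError("threshold must be in [0, 1]")
    return result["True"] >= threshold

# Usage with a result shaped like get_outcome's output
example = {"False": 0.8, "True": 0.2}
print(to_label(example))                 # False
print(to_label(example, threshold=0.1))  # True
```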

Questions?

Contact me at alba@wustl.edu
