Gemma 4 Clinical Trial Endpoint Extractor

A fine-tuned Gemma 4 E4B-it model (LoRA adapter) for extracting structured endpoint information from clinical trial text. The model takes unstructured endpoint descriptions and outputs structured JSON following a predefined schema.

Model Details

| Property | Value |
|----------|-------|
| Base Model | google/gemma-4-E4B-it |
| Method | QLoRA (4-bit NF4 quantization + LoRA rank 16) |
| Trainable Parameters | 42.4M (0.85% of 5B total) |
| Training Data | 1,558 clinical trial endpoint samples |
| Data Sources | ClinicalTrials.gov, EU CTR (EudraCT), ChiCTR (Chinese Clinical Trials) |
| Final Eval Loss | 0.3006 |
| Final Token Accuracy | 94.07% |
| Training Time | ~74 minutes on NVIDIA RTX A6000 (48 GB) |
| License | Apache 2.0 |

Training Results

| Epoch | Eval Loss | Token Accuracy |
|-------|-----------|----------------|
| 1 | 0.3493 | 93.23% |
| 2 | 0.3024 | 94.01% |
| 3 | 0.3006 | 94.07% |

Output Schema

The model outputs JSON with the following structure:

```json
{
  "endpoints": [
    {
      "endpoint_name_standardized": "string | null",
      "measurement_of": "string | null",
      "measurement_type": "continuous | binary | time-to-event | ordinal | null",
      "metric_type": "string | null",
      "timeframe": "string | null",
      "measurement_method": "string | null",
      "evaluation_criteria": "string | null",
      "unit": "string | null",
      "population": "string | null",
      "is_composite": "boolean",
      "components": "array of strings"
    }
  ]
}
```

Field Definitions

  • endpoint_name_standardized: Normalized endpoint name (e.g., ORR -> Objective Response Rate)
  • measurement_of: Underlying clinical concept being measured
  • measurement_type: One of continuous, binary, time-to-event, ordinal
  • metric_type: Statistical metric (mean, median, proportion, hazard ratio, etc.)
  • timeframe: Assessment window (e.g., "24 weeks", "Up to 5 years")
  • measurement_method: Physical tool/technique (CT scan, MRI, blood test) - NOT scoring systems
  • evaluation_criteria: Scoring system/guideline (RECIST v1.1, WHO criteria) - NOT measurement tools
  • unit: Measurement unit (%, mL, ms, ng/mL, etc.)
  • population: Study population if specified (ITT, safety population, etc.)
  • is_composite: Whether endpoint combines multiple events
  • components: List of component events for composite endpoints
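Downstream code can check parsed outputs against this schema before use. The following is a minimal, dependency-free validator sketch (the function name and error-message format are illustrative, not part of the model):

```python
# Allowed values for measurement_type, per the schema above (null is permitted).
ALLOWED_MEASUREMENT_TYPES = {"continuous", "binary", "time-to-event", "ordinal", None}

# Fields typed "string | null" in the schema.
STRING_OR_NULL_FIELDS = [
    "endpoint_name_standardized", "measurement_of", "metric_type",
    "timeframe", "measurement_method", "evaluation_criteria",
    "unit", "population",
]

def validate_endpoint(ep: dict) -> list[str]:
    """Return a list of schema violations for one endpoint dict (empty list = valid)."""
    errors = []
    for field in STRING_OR_NULL_FIELDS:
        if not isinstance(ep.get(field), (str, type(None))):
            errors.append(f"{field}: expected string or null")
    if ep.get("measurement_type") not in ALLOWED_MEASUREMENT_TYPES:
        errors.append("measurement_type: not in allowed enum")
    if not isinstance(ep.get("is_composite"), bool):
        errors.append("is_composite: expected boolean")
    if not isinstance(ep.get("components"), list):
        errors.append("components: expected list")
    elif ep.get("is_composite") is False and ep["components"]:
        errors.append("components: should be empty when is_composite is false")
    return errors
```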

Usage

With PEFT + Transformers

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
from peft import PeftModel

# IMPORTANT: Monkey-patch for Gemma 4 compatibility with PEFT
import torch.nn as nn
from transformers.models.gemma4 import modeling_gemma4

class PatchedClippableLinear(nn.Linear):
    def __init__(self, config, in_features, out_features):
        nn.Linear.__init__(self, in_features, out_features, bias=False)
        self.use_clipped_linears = getattr(config, "use_clipped_linears", False)
        if self.use_clipped_linears:
            self.register_buffer("input_min", torch.tensor(-float("inf")))
            self.register_buffer("input_max", torch.tensor(float("inf")))
            self.register_buffer("output_min", torch.tensor(-float("inf")))
            self.register_buffer("output_max", torch.tensor(float("inf")))

    def forward(self, x):
        if self.use_clipped_linears:
            x = torch.clamp(x, self.input_min, self.input_max)
        out = nn.Linear.forward(self, x)
        if self.use_clipped_linears:
            out = torch.clamp(out, self.output_min, self.output_max)
        return out

modeling_gemma4.Gemma4ClippableLinear = PatchedClippableLinear

# Load base model in 4-bit
quantization_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
)

base_model = AutoModelForCausalLM.from_pretrained(
    "google/gemma-4-E4B-it",
    quantization_config=quantization_config,
    device_map="auto",
    dtype=torch.bfloat16,
)
tokenizer = AutoTokenizer.from_pretrained("google/gemma-4-E4B-it")

# Load LoRA adapter
model = PeftModel.from_pretrained(base_model, "Shubh-0789/gemma4-clinical-endpoint-extractor")
model.eval()

# Prepare input
system_prompt = "You are a clinical trial endpoint extraction system. Extract structured endpoint information from clinical trial text. Return ONLY valid JSON."

endpoint_text = "Progression-Free Survival (PFS) assessed by RECIST v1.1 using CT scan at 24 weeks"

messages = [
    {"role": "system", "content": system_prompt},
    {"role": "user", "content": f"Extract all endpoint fields from this clinical trial text. Return ONLY a JSON.\n\nText:\n{endpoint_text}\n\nOutput (JSON only):"},
]

input_text = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
inputs = tokenizer(input_text, return_tensors="pt").to(model.device)

with torch.no_grad():
    outputs = model.generate(**inputs, max_new_tokens=1024, do_sample=False)

response = tokenizer.decode(outputs[0][inputs["input_ids"].shape[1]:], skip_special_tokens=True)
print(response)
```
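The decoded response should be pure JSON, but instruction-tuned models sometimes wrap output in markdown fences or append trailing text. A small parsing helper makes downstream use more robust (a sketch; `parse_endpoints` is not part of this repo):

```python
import json
import re

def parse_endpoints(response: str) -> dict:
    """Extract the first JSON object from a model response string.

    Strips optional ```json ... ``` fences, then locates the outermost
    braces so leading/trailing chatter does not break json.loads.
    """
    text = re.sub(r"^```(?:json)?\s*|\s*```$", "", response.strip())
    start, end = text.find("{"), text.rfind("}")
    if start == -1 or end == -1:
        raise ValueError("No JSON object found in model output")
    return json.loads(text[start:end + 1])
```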

Example Outputs

Single Endpoint

Input: Number of participants achieving CR and PR by 4 months as per IWG 2023 criteria | [Time Frame: Through 4 months after starting treatment]

Output:

```json
{
  "endpoints": [
    {
      "endpoint_name_standardized": "Number of participants achieving CR and PR",
      "measurement_of": "Objective Response Rate",
      "measurement_type": "binary",
      "metric_type": "count",
      "timeframe": "Through 4 months after starting treatment",
      "measurement_method": null,
      "evaluation_criteria": "IWG 2023 criteria",
      "unit": "number",
      "population": null,
      "is_composite": true,
      "components": ["CR", "PR"]
    }
  ]
}
```

Composite Biomarker Endpoint

Input: Biomarkers (Phase II) | Integrated biomarker endpoints include: PTEN immunohistochemistry, estrogen receptor and progesterone receptor, whole exome sequencing, ribonucleic acid sequencing. | [Time Frame: Up to 5 years]

Output:

```json
{
  "endpoints": [
    {
      "endpoint_name_standardized": "Integrated Biomarker Association with PFS",
      "measurement_of": "prognostic association of integrated biomarkers with PFS",
      "measurement_type": "continuous",
      "metric_type": "hazard ratio",
      "timeframe": "Up to 5 years",
      "measurement_method": "Integrated biomarker analysis including PTEN immunohistochemistry, estrogen receptor and progesterone receptor, whole exome sequencing, ribonucleic acid sequencing",
      "evaluation_criteria": "Proportional hazards models",
      "unit": null,
      "population": null,
      "is_composite": true,
      "components": ["PTEN immunohistochemistry", "estrogen receptor", "progesterone receptor", "whole exome sequencing", "ribonucleic acid sequencing"]
    }
  ]
}
```

Minimal Input

Input: Body height

Output:

```json
{
  "endpoints": [
    {
      "endpoint_name_standardized": "Body height",
      "measurement_of": "height",
      "measurement_type": "continuous",
      "metric_type": null,
      "timeframe": null,
      "measurement_method": "Stadiometer",
      "evaluation_criteria": null,
      "unit": "cm",
      "population": null,
      "is_composite": false,
      "components": []
    }
  ]
}
```

Training Configuration

| Parameter | Value |
|-----------|-------|
| LoRA Rank (r) | 16 |
| LoRA Alpha | 16 |
| LoRA Dropout | 0.05 |
| Target Modules | q_proj, k_proj, v_proj, o_proj, gate_proj, up_proj, down_proj |
| Learning Rate | 5e-5 |
| LR Scheduler | Cosine |
| Epochs | 3 |
| Batch Size | 1 (×4 gradient accumulation) |
| Max Sequence Length | 2048 |
| Optimizer | AdamW 8-bit |
| Quantization | 4-bit NF4 with double quantization |
| Gradient Checkpointing | Enabled |
| Warmup Steps | 50 |
| Total Steps | 1,170 |
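The step count is consistent with the data and batch settings: with an effective batch of 4 (batch size 1 × 4 gradient-accumulation steps) over 1,558 training samples, each epoch takes ⌈1,558 / 4⌉ = 390 optimizer steps, and 3 epochs give 1,170 total. A quick check:

```python
import math

samples = 1558           # training samples (see Training Data)
effective_batch = 1 * 4  # per-device batch size x gradient accumulation
epochs = 3

steps_per_epoch = math.ceil(samples / effective_batch)
total_steps = steps_per_epoch * epochs
print(steps_per_epoch, total_steps)  # 390 1170
```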

Training Data

The model was trained on 1,558 clinical trial endpoint samples (with a further 195 held out for validation) sourced from three international registries:

| Source | Endpoints |
|--------|-----------|
| ClinicalTrials.gov (USA) | ~1,100 |
| ChiCTR (China) | ~500 |
| EU CTR (Europe) | ~350 |

Ground-truth labels were generated with Qwen 3.6-plus. The dataset covers:

  • Single and multiple endpoint extraction
  • Composite endpoint detection (19% of samples)
  • Various measurement types (continuous, binary, time-to-event, ordinal)
  • Multiple therapeutic areas (oncology, cardiology, neurology, etc.)

Hardware Requirements

| Setup | VRAM |
|-------|------|
| QLoRA inference (4-bit) | ~10 GB |
| QLoRA training | ~28 GB |

Limitations

  • Ground truth labels were generated by an LLM (Qwen 3.6-plus), not manually annotated
  • May occasionally hallucinate a second endpoint for single-endpoint inputs
  • measurement_method may be inferred even when not explicitly stated in the text
  • Primarily trained on English-language endpoint descriptions
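A lightweight post-processing step can mitigate the duplicate-endpoint failure mode: drop repeated endpoints that share a standardized name. This is an illustrative sketch, not part of the released adapter:

```python
def dedupe_endpoints(result: dict) -> dict:
    """Drop repeated endpoints with the same standardized name (case-insensitive).

    Keeps the first occurrence; endpoints without a name are kept as-is.
    """
    seen = set()
    kept = []
    for ep in result.get("endpoints", []):
        name = ep.get("endpoint_name_standardized")
        key = name.strip().lower() if isinstance(name, str) else None
        if key is not None and key in seen:
            continue  # likely a hallucinated duplicate
        if key is not None:
            seen.add(key)
        kept.append(ep)
    return {**result, "endpoints": kept}
```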

Citation

If you use this model, please cite:

```bibtex
@misc{gemma4-clinical-endpoint-extractor,
  title={Gemma 4 Clinical Trial Endpoint Extractor},
  author={Shubhanshu Yadav},
  year={2026},
  publisher={Hugging Face},
  url={https://huggingface.co/Shubh-0789/gemma4-clinical-endpoint-extractor}
}
```