asd-interpreter-merged

Clinical language interpreter for ASD fMRI connectivity reports.
Fine-tuned from Qwen/Qwen2.5-7B-Instruct on AMD MI300X (ROCm 7.0) using QLoRA, then merged to a single fp16 checkpoint.

It runs live in the BrainConnect-ASD Space, generating patient-facing clinical summaries from the gradient saliency scores produced by a 20-model LOSO GCN ensemble.


Model Details

  • Base model: Qwen/Qwen2.5-7B-Instruct
  • Fine-tuning method: QLoRA (r=16, α=32, targeting the q/v projections)
  • Training hardware: AMD MI300X · ROCm 7.0 · DigitalOcean
  • Parameters: 8B (merged, fp16)
  • Context length: 4096 tokens
  • License: Apache 2.0

What It Does

Given a structured prompt containing:

  • Ensemble ASD probability p(ASD)
  • Per-model predictions from 20 LOSO site-blind GCN models
  • Network-level gradient saliency scores (7 Yeo networks: DMN, Salience, Frontoparietal, etc.)

the model outputs a clinical connectivity summary with:

  1. Overall impression and confidence level
  2. Which brain networks drove the prediction and why
  3. Site-invariance assessment (20/20 model consensus signals robustness)
  4. Recommended next steps for clinical review
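A minimal sketch of how such a structured prompt could be assembled from the ensemble outputs (function and field names are illustrative, not the pipeline's actual code):

```python
def build_prompt(p_asd, votes, n_models, saliency):
    """Assemble the structured prompt from ensemble outputs.

    p_asd     -- ensemble mean ASD probability
    votes     -- number of models predicting ASD
    n_models  -- ensemble size (20 LOSO models)
    saliency  -- dict mapping Yeo-7 network name -> gradient saliency score
    """
    # report the three most salient networks, highest first
    top = sorted(saliency.items(), key=lambda kv: kv[1], reverse=True)[:3]
    top_str = ", ".join(f"{net}={score:.4f}" for net, score in top)
    return (
        "You are a clinical neuroscience AI. Write a concise clinical "
        "connectivity summary.\n\n"
        "Patient data:\n"
        f"- p(ASD) = {p_asd:.3f} (ensemble mean across {n_models} site-blind models)\n"
        f"- Model consensus: {votes}/{n_models} models predict ASD\n"
        f"- Top network saliency: {top_str}\n\n"
        "Write a 3-paragraph clinical summary."
    )

prompt = build_prompt(0.847, 17, 20,
                      {"DMN": 0.0041, "Salience": 0.0038,
                       "Frontoparietal": 0.0029, "Visual": 0.0011})
```

This mirrors the example prompt in the Usage section below; only the top three networks make it into the prompt.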

Usage

from transformers import AutoTokenizer, AutoModelForCausalLM
import torch

model_id = "Yatsuiii/asd-interpreter-merged"
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(
    model_id, torch_dtype=torch.float16, device_map="auto"
)

prompt = """You are a clinical neuroscience AI. Write a concise clinical connectivity summary.

Patient data:
- p(ASD) = 0.847 (ensemble mean across 20 site-blind models)
- Model consensus: 17/20 models predict ASD
- Top network saliency: DMN=0.0041, Salience=0.0038, Frontoparietal=0.0029

Write a 3-paragraph clinical summary."""

# Qwen2.5-Instruct models expect the chat template
messages = [{"role": "user", "content": prompt}]
text = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)

inputs = tokenizer(text, return_tensors="pt").to(model.device)
with torch.no_grad():
    out = model.generate(**inputs, max_new_tokens=400, temperature=0.3, do_sample=True)
print(tokenizer.decode(out[0][inputs.input_ids.shape[1]:], skip_special_tokens=True))

Training Details

  • Dataset: Synthetic clinical summaries generated from ABIDE I gradient saliency outputs, manually curated for clinical tone and factual grounding
  • Fine-tuning: QLoRA via peft + trl SFTTrainer
  • Hardware: AMD MI300X (192GB HBM3), ROCm 7.0, PyTorch 2.5.1+rocm6.2
  • Epochs: 3 · Batch size: 4 · LR: 2e-4 · Warmup: 50 steps
  • Merge: LoRA adapter merged into base weights with peft.merge_and_unload()
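The adapter setup above corresponds roughly to the following peft configuration (a config sketch: r, α, and target modules are taken from this card; lora_dropout is an assumption, as the card does not state it):

```python
from peft import LoraConfig

# QLoRA adapter config matching the card: r=16, alpha=32, q/v projections
lora_config = LoraConfig(
    r=16,
    lora_alpha=32,
    target_modules=["q_proj", "v_proj"],
    lora_dropout=0.05,  # assumption: not stated in the card
    bias="none",
    task_type="CAUSAL_LM",
)

# After SFT, the adapter is folded back into a single fp16 checkpoint:
#   merged = peft_model.merge_and_unload()
#   merged.save_pretrained("asd-interpreter-merged")
```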

Integration

This model runs as a vLLM endpoint (served via rocm/vllm) and is queried by the BrainConnect-ASD Gradio Space after every inference run. If the vLLM server is unavailable, the Space falls back to a cached demo report.

Space → GCN ensemble inference → gradient saliency → structured prompt → this model → clinical report
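The Space's query-with-fallback behavior can be sketched as follows, using only the standard library (the endpoint URL, served model name, and cached text are illustrative; vLLM exposes an OpenAI-compatible /v1 API):

```python
import json
import urllib.request
import urllib.error

VLLM_URL = "http://localhost:8000/v1/completions"  # assumption: local rocm/vllm server
CACHED_REPORT = "Demo report: live interpreter unavailable, showing cached output."

def get_clinical_report(prompt, timeout=30):
    """Query the vLLM endpoint; fall back to a cached demo report on failure."""
    payload = json.dumps({
        "model": "asd-interpreter-merged",  # assumption: served model name
        "prompt": prompt,
        "max_tokens": 400,
        "temperature": 0.3,
    }).encode()
    req = urllib.request.Request(
        VLLM_URL, data=payload, headers={"Content-Type": "application/json"}
    )
    try:
        with urllib.request.urlopen(req, timeout=timeout) as resp:
            return json.load(resp)["choices"][0]["text"]
    except (OSError, KeyError):
        # connection refused, timeout, or malformed response -> cached demo
        return CACHED_REPORT
```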

Limitations

  • Trained on synthetic data derived from ABIDE I — not validated on real clinical populations
  • Not a medical device. Outputs are for research and demonstration purposes only.
  • Performance degrades on atlases other than CC200 (saliency prompt was optimized for CC200 → Yeo-7 mapping)
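The CC200 dependency comes from how region-level saliency is pooled into the Yeo-7 network scores fed to the prompt; a minimal sketch of that aggregation (the real 200-region lookup table is not reproduced here):

```python
def network_saliency(region_saliency, region_to_network):
    """Pool per-region gradient saliency (CC200 regions) into network scores.

    region_saliency   -- one saliency score per CC200 region
    region_to_network -- Yeo-7 network name for each region index
    """
    grouped = {}
    for region, net in enumerate(region_to_network):
        grouped.setdefault(net, []).append(region_saliency[region])
    # mean saliency per network, as reported in the structured prompt
    return {net: sum(v) / len(v) for net, v in grouped.items()}

# toy example with 4 regions (assumption: the real pipeline has 200)
sal = [0.0040, 0.0042, 0.0038, 0.0029]
mapping = ["DMN", "DMN", "Salience", "Frontoparietal"]
# network_saliency(sal, mapping)["DMN"] -> mean of 0.0040 and 0.0042, i.e. 0.0041
```

A different atlas changes both the region count and this mapping, which is why saliency prompts built for other atlases degrade.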

Citation

If you use this model or the BrainConnect-ASD pipeline, please cite:

BrainConnect-ASD — AMD Developer Hackathon 2026
Raghav Aryen · lablab.ai · AMD MI300X
https://huggingface.co/spaces/lablab-ai-amd-developer-hackathon/BrainConnect-ASD