PsychGNN Masked-Edge Imputation Model

Model summary

This repository contains a heterogeneous graph neural network trained to recover masked SNP-disorder links within a psychiatric cross-disorder graph.

The model is intended for variant-level research use. It does not predict patient-level diagnosis, prognosis, or treatment response.

Intended task

The training and evaluation task is disorder-conditional masked-edge imputation:

  • start from a graph containing SNP, gene, and disorder nodes
  • remove a subset of observed SNP-disorder links from message passing
  • retain the same SNP's remaining cross-disorder context where available
  • predict whether the hidden SNP-disorder link should exist

This task answers a specific question:

Given a psychiatric variant already represented in the graph, can the model recover missing cross-disorder links from the rest of the graph structure?
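
The masking step described above can be sketched roughly as follows. This is an illustration, not the repository's training code: the function name `mask_snp_disorder_edges`, the 10% mask fraction, and the choice of negatives (same disorder, a randomly drawn SNP with no observed link) are all assumptions. Only the 1:1 negative sampling ratio is taken from the training configuration.

```python
import random


def mask_snp_disorder_edges(edges, num_snps, mask_fraction=0.1, neg_ratio=1, seed=0):
    """Split observed SNP-disorder edges into message-passing and held-out sets.

    edges: list of (snp_index, disorder_index) pairs observed in the graph.
    Returns (message_edges, held_out_positives, sampled_negatives).
    """
    rng = random.Random(seed)
    shuffled = list(edges)
    rng.shuffle(shuffled)

    # Held-out edges are hidden from message passing and used as positive targets.
    n_masked = max(1, int(len(shuffled) * mask_fraction))
    held_out = shuffled[:n_masked]
    message_edges = shuffled[n_masked:]

    observed = set(edges)
    negatives = []
    for _, disorder in held_out:
        for _ in range(neg_ratio):
            # Assumed negative scheme: keep the disorder, draw a SNP with no
            # observed link to it. Bounded rejection sampling avoids looping
            # forever if a disorder is linked to almost every SNP.
            for _ in range(100):
                snp = rng.randrange(num_snps)
                if (snp, disorder) not in observed:
                    negatives.append((snp, disorder))
                    break
    return message_edges, held_out, negatives
```

Because only the masked pair is removed, any other links the same SNP has to other disorders stay in `message_edges`, which is what keeps the remaining cross-disorder context available to the encoder.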

Data provenance

The checkpoint was trained on:

The harmonized dataset was derived from public OpenMed / PGC Hugging Face repositories, including:

Scope

The underlying graph contains 11 modeled disorder groups:

  • ADHD
  • Anxiety
  • Autism
  • Bipolar disorder
  • Borderline personality disorder
  • Eating disorders
  • Major depressive disorder
  • Obsessive-compulsive disorder
  • Post-traumatic stress disorder
  • Schizophrenia
  • Substance use

At inference time, the checkpoint can score any of these 11 disorders for variants present in the graph artifact, although only five of them have masked-edge benchmark results (see Evaluation).

Architecture

The model is a custom heterogeneous GraphSAGE-style network over three node types:

  • SNP nodes
  • gene nodes
  • disorder nodes

Message passing uses:

  • SNP self-updates
  • SNP-to-gene aggregation
  • gene self-updates
  • gene-to-SNP aggregation
  • disorder self-updates
  • SNP-to-disorder aggregation

This release uses no disorder-disorder edges.
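
One relation in the message-passing scheme above, for example SNP-to-gene aggregation, could look roughly like this in NumPy. The mean aggregator, the ReLU nonlinearity, the separate self and neighbour weight matrices, and the function names are assumptions made for illustration; the checkpoint's actual layer definitions may differ.

```python
import numpy as np


def mean_aggregate(src_feats, edge_index, num_dst):
    """Average source-node features over the incoming edges of each destination node.

    edge_index: integer array of shape (2, num_edges); row 0 holds source
    indices and row 1 holds destination indices (an assumed layout).
    Destination nodes without incoming edges receive a zero vector.
    """
    src, dst = edge_index
    summed = np.zeros((num_dst, src_feats.shape[1]))
    np.add.at(summed, dst, src_feats[src])  # unbuffered scatter-add per destination
    counts = np.bincount(dst, minlength=num_dst).reshape(-1, 1)
    return summed / np.maximum(counts, 1)


def sage_update(self_feats, neigh_feats, w_self, w_neigh):
    """GraphSAGE-style update: self transform plus neighbour transform, then ReLU."""
    return np.maximum(self_feats @ w_self + neigh_feats @ w_neigh, 0.0)
```

In a heterogeneous layer, the same pattern would be repeated per relation (SNP-to-gene, gene-to-SNP, SNP-to-disorder), each with its own weights.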

Decoder heads

The checkpoint contains two task heads:

  • a bilinear link decoder for SNP-disorder link scoring
  • an effect-size regression head for predicting a normalized effect estimate on positive edges
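
A bilinear link decoder typically scores a pair as z_snp^T W z_disorder. The sketch below shows that form only; whether the checkpoint adds a bias term, or how the effect-size head is parameterized, is not specified here, so the names and shapes are assumptions.

```python
import numpy as np


def bilinear_link_logit(z_snp, z_dis, w):
    """Row-wise bilinear score z_snp[n]^T @ w @ z_dis[n] for each scored pair."""
    return np.einsum("ni,ij,nj->n", z_snp, w, z_dis)


def sigmoid(x):
    """Map a logit to a value in (0, 1)."""
    return 1.0 / (1.0 + np.exp(-x))
```

With a hidden dimension of 192, `w` would be a 192 x 192 matrix under this assumed form, and `sigmoid(bilinear_link_logit(...))` would give a link score between 0 and 1.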

Training configuration

Best hyperparameters:

  • hidden dimension: 192
  • layers: 3
  • dropout: 0.2
  • learning rate: 7.5e-4
  • weight decay: 1e-5
  • negative sampling ratio: 1

Checkpoint metadata:

  • SNP feature dimension: 7
  • gene feature dimension: 4
  • disorder feature dimension: 5
  • effect normalization mean: -0.001849
  • effect normalization std: 0.094672
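
Assuming the effect normalization is a standard z-score, (effect - mean) / std, predictions from the effect-size head can be mapped back to the original scale with the stored constants. The helper names below are hypothetical.

```python
EFFECT_MEAN = -0.001849  # effect normalization mean from the checkpoint metadata
EFFECT_STD = 0.094672    # effect normalization std from the checkpoint metadata


def normalize_effect(effect):
    """Map a raw effect estimate onto the normalized training scale (assumed z-score)."""
    return (effect - EFFECT_MEAN) / EFFECT_STD


def denormalize_effect(z):
    """Map a normalized prediction back to the raw effect scale."""
    return z * EFFECT_STD + EFFECT_MEAN
```

Under this assumption, a normalized prediction of 1.0 corresponds to a raw effect of 0.094672 - 0.001849 = 0.092823.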

Graph context

Graph metadata for this release:

  • variants: 18,979
  • genes: 1,205
  • disorders: 11
  • SNP-disorder edges: 22,687
  • SNP-gene edges: 65,634
  • disorder-disorder edges: 0
  • GWS threshold for graph edge construction: 5e-8
  • SNP-gene positional window: 100,000 bp
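
The two construction thresholds can be read as simple predicates. The sketch below assumes a strict inequality for genome-wide significance and a window measured from the gene boundaries on the same chromosome; the release may instead measure from another reference point, so treat these helpers as hypothetical.

```python
GWS_THRESHOLD = 5e-8   # genome-wide significance threshold from the graph metadata
WINDOW_BP = 100_000    # SNP-gene positional window from the graph metadata


def is_gws(p_value, threshold=GWS_THRESHOLD):
    """True if an association is strong enough to become a SNP-disorder edge."""
    return p_value < threshold


def snp_in_gene_window(snp_chrom, snp_pos, gene_chrom, gene_start, gene_end, window=WINDOW_BP):
    """True if the SNP lies within `window` bp of the gene on the same chromosome."""
    if snp_chrom != gene_chrom:
        return False
    return gene_start - window <= snp_pos <= gene_end + window
```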

Evaluation

Primary masked-edge benchmark:

  • test AUROC: 0.9805
  • test AP: 0.9818
  • effect-size Pearson r: 0.9379
  • best validation AUROC: 0.9759
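
For readers who want to recompute ranking metrics from their own scores, AUROC and AP can be written in a few lines of plain Python. This is a reference sketch, not the evaluation code behind the numbers above; tie handling in `average_precision` follows sort order and may differ slightly from library implementations.

```python
def auroc(labels, scores):
    """Probability that a random positive outscores a random negative (ties count 0.5)."""
    pos = [s for y, s in zip(labels, scores) if y == 1]
    neg = [s for y, s in zip(labels, scores) if y == 0]
    if not pos or not neg:
        raise ValueError("AUROC needs at least one positive and one negative")
    wins = sum((p > n) + 0.5 * (p == n) for p in pos for n in neg)
    return wins / (len(pos) * len(neg))


def average_precision(labels, scores):
    """Mean of the precision values measured at the rank of each positive."""
    ranked = sorted(zip(scores, labels), key=lambda pair: -pair[0])
    hits, total = 0, 0.0
    for rank, (_, label) in enumerate(ranked, start=1):
        if label == 1:
            hits += 1
            total += hits / rank
    if hits == 0:
        raise ValueError("average precision needs at least one positive")
    return total / hits
```

If evaluation negatives are sampled at the same 1:1 ratio used in training (an assumption), positives and negatives are balanced, and a scorer with no signal would be expected to land near 0.5 on both metrics.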

Per-disorder benchmark coverage:

Disorder                    AUROC    AP       Test edges
ADHD                        0.9707   0.9698   56
Anxiety                     0.9907   0.9920   350
Bipolar disorder            0.9929   0.9919   544
Major depressive disorder   0.9899   0.9898   244
Schizophrenia               0.9623   0.9674   604

The following disorders are omitted from the masked-edge metric table because the current graph does not provide enough stable maskable positives for this benchmark:

  • Autism
  • Borderline personality disorder
  • Eating disorders
  • Obsessive-compulsive disorder
  • Post-traumatic stress disorder
  • Substance use

Baseline comparison

Baseline              Test AUROC   Test AP
Disorder prevalence   0.5000       0.5000
Variant degree        0.5661       0.6184
Additive prior        0.5430       0.5402
Low-rank SVD          0.5994       0.6413
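
The baseline definitions are not spelled out here. As one plausible reading of the "Variant degree" row, a pair could be scored by how many SNP-disorder links the variant keeps in the visible (unmasked) graph. The function below is a hypothetical reconstruction for illustration, not the code that produced the table.

```python
from collections import Counter


def variant_degree_scores(message_edges, query_pairs):
    """Score each (snp, disorder) query by the SNP's visible SNP-disorder degree.

    message_edges: (snp, disorder) pairs left in the graph after masking.
    query_pairs: (snp, disorder) pairs to score; the disorder is ignored.
    """
    degree = Counter(snp for snp, _ in message_edges)
    return [degree[snp] for snp, _ in query_pairs]
```

A score like this ignores which disorder is being queried, which is one reason a structure-aware model would be expected to outperform it.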

Inputs and outputs

Inputs

The checkpoint expects:

  • SNP feature matrix
  • gene feature matrix
  • disorder feature matrix
  • SNP-gene edge index
  • SNP-disorder edge index
  • variant and disorder mappings

These are provided by the associated public graph artifact.
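
Before running inference, it can help to check that loaded inputs match the dimensions reported in this card. The sketch below assumes the features are array-like with a `.shape` attribute and that edge indices use a (2, num_edges) layout; the dictionary keys and function names are hypothetical.

```python
import numpy as np

# (number of nodes, feature dimension) per node type, taken from this card.
EXPECTED_SHAPES = {"snp": (18_979, 7), "gene": (1_205, 4), "disorder": (11, 5)}


def check_features(features, expected=EXPECTED_SHAPES):
    """Return a list of shape mismatches; the list is empty when everything matches."""
    problems = []
    for name, shape in expected.items():
        actual = tuple(features[name].shape)
        if actual != shape:
            problems.append(f"{name}: expected {shape}, got {actual}")
    return problems


def check_edge_index(edge_index, num_src, num_dst):
    """True if edge_index has shape (2, E) and every index is within range."""
    edge_index = np.asarray(edge_index)
    if edge_index.ndim != 2 or edge_index.shape[0] != 2:
        return False
    if edge_index.size == 0:
        return True
    src, dst = edge_index
    return bool(src.min() >= 0 and src.max() < num_src
                and dst.min() >= 0 and dst.max() < num_dst)
```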

Outputs

For a scored (variant, disorder) pair, the model produces:

  • a link score indicating whether the SNP-disorder edge should exist under the masked-edge task
  • an effect-size prediction on the normalized training scale

The primary output of this release is the masked-edge link score.

How to use

Minimal checkpoint loading:

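import torch
from huggingface_hub import hf_hub_download

# Download the checkpoint file from the Hugging Face Hub (cached locally after the first call).
ckpt_path = hf_hub_download(
    "lighteternal/psychgnn-masked-edge-imputation-model",
    "model.pt",
    repo_type="model",
)
# weights_only=False unpickles arbitrary Python objects, so load only checkpoints you trust.
checkpoint = torch.load(ckpt_path, map_location="cpu", weights_only=False)

# Inspect the stored training configuration and the headline test metric.
print(checkpoint["hyperparams"])
print(checkpoint["feature_dims"])
print(checkpoint["report"]["masked_split"]["test_auroc"])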

To run inference, instantiate a heterogeneous GraphSAGE-style model matching the architecture above, load checkpoint["state_dict"], and score variant-disorder pairs against lighteternal/psychgnn-psychiatric-graph.
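
Since the model class is not shipped as a file in this repository, a practical first step when rebuilding it is to list the parameter names and shapes stored in the checkpoint. The helper below is a hypothetical convenience that works on any mapping whose values expose a `.shape` attribute.

```python
def summarize_state_dict(state_dict):
    """Return (name, shape) pairs for every entry that exposes a .shape attribute."""
    return [
        (name, tuple(value.shape))
        for name, value in state_dict.items()
        if hasattr(value, "shape")
    ]
```

Calling `summarize_state_dict(checkpoint["state_dict"])` after the loading snippet above should show which layers exist and how large they are, which can then be matched against the hidden dimension of 192 and the three layers listed under Training configuration.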

Files in this repository

  • model.pt
  • evaluation_report.json

Limitations

  • This model is evaluated on a masked-edge recovery task, not on patient outcomes.
  • The headline benchmark covers only disorders with enough maskable positive edges under the current graph construction.
  • The graph uses a strict GWS edge definition; this reduces benchmark coverage for some disorders.
  • The checkpoint should not be interpreted as evidence of robust generalization to completely unseen disorders.
  • The effect-size head is trained on harmonized summary-statistics edges and does not constitute a causal estimate.

Appropriate use

Reasonable uses:

  • cross-disorder variant prioritization
  • exploratory pleiotropy analysis
  • follow-up prioritization for psychiatric genetics studies
  • downstream research tooling built on the published graph

Inappropriate uses:

  • clinical decision-making
  • diagnosis or screening
  • patient-level risk prediction
  • treatment selection