---
library_name: pytorch
license: mit
pipeline_tag: unconditional-image-generation
tags:
- hdtree
- pytorch
- mnist
- single-cell
- clustering
---

# HDTree ICML Checkpoints

This repository hosts pretrained checkpoints for the model presented in the paper [HDTree: Generative Modeling of Cellular Hierarchies for Robust Lineage Inference](https://huggingface.co/papers/2506.23287).

HDTree is a generative modeling framework designed for robust lineage inference. It captures tree relationships within a hierarchical latent space using a unified hierarchical codebook and employs a quantized diffusion process to model continuous cell state transitions.

- **Code:** [https://github.com/zangzelin/code_HDTree_icml](https://github.com/zangzelin/code_HDTree_icml)
- **Project Page:** [https://zangzelin.github.io/code_HDTree_icml/](https://zangzelin.github.io/code_HDTree_icml/)

## Files

| File | Dataset | Configuration | Notes |
|---|---|---|---|
| `checkpoints/mnist/hdtree_mnist_best_epoch59_acc0.97570.pth` | MNIST | `configs/mnist.yaml` | Best checkpoint from the full MNIST run, selected by validation accuracy. |
| `checkpoints/limb/hdtree_limb_i10_epoch199_acc0.53921.pth` | Limb | `configs/limb.yaml` (default settings) | Checkpoint from the i10 run of the Limb sweep (default configuration). |

## Sample Usage

To validate a trained checkpoint using the official code, you can use the provided validation script:

```bash
# Example for MNIST
bash scripts/validate_checkpoint.sh mnist checkpoints/mnist/hdtree_mnist_best_epoch59_acc0.97570.pth
```
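The same script can in principle validate both released checkpoints in one pass. The sketch below is a convenience wrapper, not part of the official code: `validate_all` is a hypothetical helper, and it assumes you run it from the root of the code repository with the checkpoints downloaded into `./checkpoints/`, and that `validate_checkpoint.sh` accepts `limb` as a dataset name in the same way as `mnist`.

```bash
# Hypothetical wrapper: validate both released checkpoints, one after another.
# Assumptions (not confirmed by this card): run from the code repository root,
# checkpoints live under ./checkpoints/, and the script accepts "limb" like "mnist".
validate_all() {
  local entry ds ckpt
  for entry in \
    "mnist checkpoints/mnist/hdtree_mnist_best_epoch59_acc0.97570.pth" \
    "limb checkpoints/limb/hdtree_limb_i10_epoch199_acc0.53921.pth"; do
    # Split each entry into the dataset name and its checkpoint path.
    read -r ds ckpt <<< "$entry"
    if [[ ! -f "$ckpt" ]]; then
      echo "missing checkpoint for $ds: $ckpt" >&2
      return 1
    fi
    echo "validating $ds" >&2
    # Stop at the first failing validation run.
    bash scripts/validate_checkpoint.sh "$ds" "$ckpt" || return 1
  done
}

# Usage:
# validate_all
```

Checking for the file up front gives a clearer error than letting the script fail on a missing path.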

To compute reconstruction and log-likelihood metrics with diffusion sampling, enable generation by setting `gen_data_bool=True`:

```bash
python main.py validate \
  -c configs/mnist.yaml \
  --model.init_args.ckpt_path=checkpoints/mnist/hdtree_mnist_best_epoch59_acc0.97570.pth \
  --model.init_args.training_str=step2_r \
  --model.init_args.gen_data_bool=True
```

## Reported Metrics

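Metrics: ACC is clustering accuracy, DP is dendrogram purity, LP is leaf purity, and NMI is normalized mutual information.

MNIST full run summary: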

| ACC | DP | LP | NMI |
|---:|---:|---:|---:|
| 0.97310 | 0.93262 | 0.97310 | 0.92999 |

Limb i10 run summary (`batch_size=1000`, `K=10`, `exaggeration_lat=0.5`, `nu_lat=0.3`):

| ACC | DP | LP | NMI |
|---:|---:|---:|---:|
| 0.52860 | 0.41029 | 0.58370 | 0.49042 |

The included `logs/` files contain the original run outputs used to record these metrics.

## Download

```bash
pip install huggingface_hub
huggingface-cli download zelinzang/HDTree-ICML-checkpoints --local-dir .
```
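To fetch a single checkpoint instead of the whole repository, `huggingface-cli download` also accepts explicit file paths after the repo ID. The following is a minimal sketch: `checkpoint_path` and `download_checkpoint` are hypothetical helper names, and it assumes `SHA256SUMS` sits at the repository root.

```bash
# Hypothetical helper: map a dataset name to its checkpoint path in this repo.
checkpoint_path() {
  case "$1" in
    mnist) echo "checkpoints/mnist/hdtree_mnist_best_epoch59_acc0.97570.pth" ;;
    limb)  echo "checkpoints/limb/hdtree_limb_i10_epoch199_acc0.53921.pth" ;;
    *)     echo "unknown dataset: $1 (expected mnist or limb)" >&2; return 1 ;;
  esac
}

# Hypothetical helper: download one checkpoint plus the checksum file.
# Assumes SHA256SUMS is stored at the repository root.
download_checkpoint() {
  local ckpt
  ckpt="$(checkpoint_path "$1")" || return 1
  # Listing file paths after the repo ID fetches only those files,
  # preserving their relative paths under --local-dir.
  huggingface-cli download zelinzang/HDTree-ICML-checkpoints \
    "$ckpt" SHA256SUMS --local-dir .
}

# Usage:
# download_checkpoint limb
```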

## Checksums

SHA-256 checksums are listed in `SHA256SUMS`.
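Downloaded files can be checked against that list with GNU coreutils. This is a sketch under an assumption this card does not confirm: that `SHA256SUMS` uses the standard `<hex digest>  <relative path>` format written by `sha256sum`. The helper name `verify_checksums` is hypothetical.

```bash
# Hypothetical helper: verify files in a download directory against SHA256SUMS.
# Assumes the standard "<hex digest>  <relative path>" format from `sha256sum`.
verify_checksums() {
  local dir="${1:-.}"
  # Run in a subshell inside the directory so relative paths resolve;
  # --quiet prints only failures, and the exit status is non-zero on mismatch.
  ( cd "$dir" && sha256sum --check --quiet SHA256SUMS )
}

# Usage: check the directory passed to --local-dir during download.
# verify_checksums .
```

A non-zero exit status means at least one file is missing or corrupted and should be downloaded again.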