Add format section explaining safetensors compatibility
Browse files
README.md
CHANGED
|
@@ -66,14 +66,23 @@ Pre-trained [TabDPT](https://github.com/JesseCresswell/tfm-mia) model weights tr
|
|
| 66 |
| CTR Correlation | 0.830 | 0.830 | 0.827 |
|
| 67 |
| CTR R² | 0.726 | 0.730 | 0.725 |
|
| 68 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 69 |
## Usage
|
| 70 |
|
| 71 |
```python
|
| 72 |
from tabdpt import TabDPTClassifier
|
| 73 |
from huggingface_hub import hf_hub_download
|
| 74 |
|
|
|
|
| 75 |
path = hf_hub_download("dwahdany/TabDPT", "production_seed42.safetensors")
|
|
|
|
|
|
|
| 76 |
clf = TabDPTClassifier(model_weight_path=path)
|
| 77 |
clf.fit(X_train, y_train)
|
| 78 |
preds = clf.predict(X_test)
|
| 79 |
```
|
|
|
|
|
|
|
|
|
| 66 |
| CTR Correlation | 0.830 | 0.830 | 0.827 |
|
| 67 |
| CTR R² | 0.726 | 0.730 | 0.725 |
|
| 68 |
|
| 69 |
+
## Format
|
| 70 |
+
|
| 71 |
+
These checkpoints were converted from PyTorch Lightning `.ckpt` files (which include optimizer state, ~295MB each) to SafeTensors format (model weights only, ~103MB each). This is the same format used by the official `Layer6/TabDPT` release. The `tabdpt` package natively loads SafeTensors via the `model_weight_path` argument — no extra conversion needed.
|
| 72 |
+
|
| 73 |
## Usage
|
| 74 |
|
| 75 |
```python
|
| 76 |
from tabdpt import TabDPTClassifier
|
| 77 |
from huggingface_hub import hf_hub_download
|
| 78 |
|
| 79 |
+
# Download once (cached afterwards)
|
| 80 |
path = hf_hub_download("dwahdany/TabDPT", "production_seed42.safetensors")
|
| 81 |
+
|
| 82 |
+
# Use exactly like the default model
|
| 83 |
clf = TabDPTClassifier(model_weight_path=path)
|
| 84 |
clf.fit(X_train, y_train)
|
| 85 |
preds = clf.predict(X_test)
|
| 86 |
```
|
| 87 |
+
|
| 88 |
+
Works identically with `TabDPTRegressor`.
|